diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 22db995c5..2d49707a5 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -35,6 +35,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -106,10 +107,10 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- @@ -118,6 +119,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -191,6 +193,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -243,6 +246,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 127c8ff4f..cf62d35d8 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -33,6 +33,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml index 2e6359a78..b7ffee9f4 100644 --- a/.github/workflows/publish-test.yml +++ b/.github/workflows/publish-test.yml @@ -21,6 +21,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index dde6cc2dc..c592756b8 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,6 +23,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 3e0be0d2b..44151ecb3 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -47,6 +47,7 @@ def __init__( retry_stop_after_attempts_duration: Optional[float] = None, retry_delay_default: Optional[float] = None, retry_dangerous_codes: Optional[List[int]] = None, + respect_server_retry_after_header: Optional[bool] = None, proxy_auth_method: Optional[str] = None, pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, @@ -79,6 +80,7 @@ def __init__( ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] + self.respect_server_retry_after_header = bool(respect_server_retry_after_header) self.proxy_auth_method = proxy_auth_method self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 4281883da..140e7845b 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -94,6 +94,7 @@ def __init__( stop_after_attempts_duration: float, delay_default: float, force_dangerous_codes: List[int], + respect_server_retry_after_header: bool = False, urllib3_kwargs: dict = {}, ): # These values do not change from one command to the next @@ -103,6 +104,7 @@ def __init__( self.stop_after_attempts_duration = stop_after_attempts_duration self._delay_default = delay_default self.force_dangerous_codes = force_dangerous_codes + self.respect_server_retry_after_header = respect_server_retry_after_header # the urllib3 kwargs are a mix of configuration (some of which we override) # and counters like `total` or `connect` which may change between successive retries @@ -202,6 +204,7 @@ def new( stop_after_attempts_duration=self.stop_after_attempts_duration, delay_default=self.delay_default, force_dangerous_codes=self.force_dangerous_codes, + respect_server_retry_after_header=self.respect_server_retry_after_header, urllib3_kwargs={}, ) @@ -323,7 +326,9 @@ def get_backoff_time(self) -> float: return proposed_backoff - def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: + def should_retry( + self, method: str, status_code: int, has_retry_after: bool = False + ) -> Tuple[bool, str]: """This method encapsulates the connector's approach to retries. We always retry a request unless one of these conditions is met: @@ -381,6 +386,15 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if not self._is_method_retryable(method): return False, "Only POST requests are retried" + # When respect_server_retry_after_header is enabled, only retry when the + # server explicitly signals it's safe via a Retry-After header. This prevents + # duplicate side effects for non-idempotent operations. + if self.respect_server_retry_after_header and not has_retry_after: + return ( + False, + "respect_server_retry_after_header mode: no Retry-After header present", + ) + # Request failed with 404 and was a GetOperationStatus. This is not recoverable. Don't retry. if status_code == 404 and self.command_type == CommandType.GET_OPERATION_STATUS: return ( @@ -450,7 +464,7 @@ def is_retry( Logs a debug message if the request will be retried """ - should_retry, msg = self.should_retry(method, status_code) + should_retry, msg = self.should_retry(method, status_code, has_retry_after) if should_retry: logger.debug(msg) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index b47f2add2..476bddb17 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -90,6 +90,9 @@ def __init__( ) self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._respect_server_retry_after_header = kwargs.get( + "_respect_server_retry_after_header", False + ) # Connection pooling settings self.max_connections = kwargs.get("max_connections", 10) @@ -114,6 +117,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, + respect_server_retry_after_header=self._respect_server_retry_after_header, urllib3_kwargs=urllib3_kwargs, ) else: diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d2b10e718..7cfe181bc 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -189,6 +189,9 @@ def __init__( " This behaviour is deprecated and will be removed in a future release." ) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._respect_server_retry_after_header = kwargs.get( + "_respect_server_retry_after_header", False + ) additional_transport_args = {} @@ -215,6 +218,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, + respect_server_retry_after_header=self._respect_server_retry_after_header, urllib3_kwargs=urllib3_kwargs, ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 7ccd69c54..72b5597cc 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -99,6 +99,7 @@ def _setup_pool_managers(self): stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, + respect_server_retry_after_header=self.config.respect_server_retry_after_header, ) # Initialize the required attributes that DatabricksRetryPolicy expects diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9f96e8743..93ab980f8 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -919,6 +919,9 @@ def build_client_context(server_hostname: str, version: str, **kwargs): ), retry_delay_default=kwargs.get("_retry_delay_default"), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + respect_server_retry_after_header=kwargs.get( + "_respect_server_retry_after_header" + ), proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index dd7c56996..ad8538f8b 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -106,11 +106,11 @@ def test_query_with_large_narrow_result_set(self, extra_params): ], ) def test_long_running_query(self, extra_params): - """Incrementally increase query size until it takes at least 3 minutes, + """Incrementally increase query size until it takes at least 2 minutes, and asserts that the query completes successfully. """ minutes = 60 - min_duration = 3 * minutes + min_duration = 2 * minutes duration = -1 scale0 = 10000 @@ -136,5 +136,5 @@ def test_long_running_query(self, extra_params): duration = time.time() - start current_fraction = duration / min_duration print("Took {} s with scale factor={}".format(duration, scale_factor)) - # Extrapolate linearly to reach 3 min and add 50% padding to push over the limit + # Extrapolate linearly to reach 2 min and add 50% padding to push over the limit scale_factor = math.ceil(1.5 * scale_factor / current_fraction) diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 897a1d111..635b611e2 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -7,9 +7,8 @@ class TestRetry: - @pytest.fixture() - def retry_policy(self) -> DatabricksRetryPolicy: - return DatabricksRetryPolicy( + def _make_retry_policy(self, **overrides) -> DatabricksRetryPolicy: + defaults = dict( delay_min=1, delay_max=30, stop_after_attempts_count=3, @@ -17,6 +16,12 @@ def retry_policy(self) -> DatabricksRetryPolicy: delay_default=2, force_dangerous_codes=[], ) + defaults.update(overrides) + return DatabricksRetryPolicy(**defaults) + + @pytest.fixture() + def retry_policy(self) -> DatabricksRetryPolicy: + return self._make_retry_policy() @pytest.fixture() def error_history(self) -> RequestHistory: @@ -83,3 +88,82 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy): retry_policy.sleep(HTTPResponse(status=503)) # Internally urllib3 calls the increment function generating a new instance for every retry retry_policy = retry_policy.increment() + + def test_respect_server_retry_after__retries_with_retry_after(self): + """429 + Retry-After header → should retry""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True) + assert should_retry is True + + def test_respect_server_retry_after__no_retry_without_retry_after(self): + """429 without Retry-After header → no retry""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + should_retry, msg = policy.should_retry("POST", 429, has_retry_after=False) + assert should_retry is False + assert "respect_server_retry_after_header" in msg + + def test_respect_server_retry_after__no_retry_503_without_header(self): + """503 without Retry-After header → no retry""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + should_retry, msg = policy.should_retry("POST", 503, has_retry_after=False) + assert should_retry is False + assert "respect_server_retry_after_header" in msg + + def test_respect_server_retry_after__overrides_dangerous_codes(self): + """force_dangerous_codes=[500] + no Retry-After → no retry in respect_server_retry_after_header mode""" + policy = self._make_retry_policy( + force_dangerous_codes=[500], respect_server_retry_after_header=True + ) + policy._retry_start_time = time.time() + policy.command_type = CommandType.EXECUTE_STATEMENT + should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False) + assert should_retry is False + assert "respect_server_retry_after_header" in msg + + def test_respect_server_retry_after__non_retryable_codes_unaffected(self): + """401/403/501 still don't retry even with Retry-After header""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + for code in [401, 403, 501]: + should_retry, msg = policy.should_retry( + "POST", code, has_retry_after=True + ) + assert should_retry is False, f"Code {code} should never retry" + + def test_default_mode_unchanged(self, retry_policy): + """respect_server_retry_after_header=False preserves existing behavior — 429 retries without header""" + retry_policy._retry_start_time = time.time() + retry_policy.command_type = CommandType.OTHER + should_retry, msg = retry_policy.should_retry( + "POST", 429, has_retry_after=False + ) + assert should_retry is True + + def test_respect_server_retry_after__survives_new(self): + """urllib3 calls .new() between retries to create a fresh policy instance. + Verify that respect_server_retry_after_header is carried over and still enforced.""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + new_policy = policy.new() + assert new_policy.respect_server_retry_after_header is True + # The new instance should still block retries without Retry-After + should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False) + assert should_retry is False + assert "respect_server_retry_after_header" in msg + + def test_respect_server_retry_after__execute_statement_with_retry_after(self): + """EXECUTE_STATEMENT + 429 + Retry-After header → retry""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.EXECUTE_STATEMENT + should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True) + assert should_retry is True +