Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
9a925ba
DatabaseAccountRetry - initial commit
dibahlfi Aug 13, 2025
edab3e5
Update sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py
dibahlfi Aug 14, 2025
0b639fd
DatabaseAccountRetry - adding comments
dibahlfi Aug 14, 2025
72935f1
Merge branch 'users/dibahl/client_reset_connection' of https://github…
dibahlfi Aug 14, 2025
d2920df
fixing pylink comments
dibahlfi Aug 14, 2025
043fe06
Merge branch 'main' into users/dibahl/client_reset_connection
dibahlfi Aug 14, 2025
f4061c1
updated CHANGELOG.md
dibahlfi Aug 14, 2025
0dbf808
fixing bug for write region unavailability
dibahlfi Aug 15, 2025
c163a04
fix: comments
dibahlfi Aug 15, 2025
a8ff872
fix: comments
dibahlfi Aug 16, 2025
10f77ad
fix: comments
dibahlfi Aug 17, 2025
716e619
fix: change log update
dibahlfi Aug 19, 2025
079dde0
fix: fix the absolute timeout enforcement provided by the client
dibahlfi Aug 22, 2025
c1ce9f7
merging main and resolving conflicts
dibahlfi Aug 22, 2025
6151bb1
fix: change log update
dibahlfi Aug 22, 2025
1edb28f
fix: fixing comments
dibahlfi Aug 22, 2025
6c8de11
fixing absolute timeout for logical operations
dibahlfi Aug 26, 2025
451d1c9
fix: adding support for absolute timeout
dibahlfi Sep 2, 2025
8ac49a0
fix: fixing read_timeout bugs
dibahlfi Sep 3, 2025
34c6150
cleaning up
dibahlfi Sep 5, 2025
a097f27
fix: refactoring and cleaning up
dibahlfi Sep 6, 2025
4dfa9e2
fix: renaming tests
dibahlfi Sep 6, 2025
5383881
cleaning up
dibahlfi Sep 9, 2025
ad75d86
merging main
dibahlfi Sep 11, 2025
0fabcf4
fix: cleaning up
dibahlfi Sep 12, 2025
4f1caa2
Merge branch 'main' into users/dibahl/absolute_timeout_fix
dibahlfi Sep 12, 2025
8579929
fix: updated CHANGELOG.md
dibahlfi Sep 12, 2025
19e16e5
fix: cleaning up
dibahlfi Sep 12, 2025
bf235e9
fix: cleaning up test
dibahlfi Sep 12, 2025
e3aa929
fix: addressing comments
dibahlfi Sep 21, 2025
296f0ad
fix: removing white spaces
dibahlfi Sep 21, 2025
6270d32
fix: addressing comments
dibahlfi Sep 23, 2025
e7bf27f
fix: cleaning up
dibahlfi Sep 23, 2025
8422a8b
fix: fixing tests
dibahlfi Sep 30, 2025
0574821
fix: fixing tests
dibahlfi Sep 30, 2025
d3585a5
fix: addressing comments
dibahlfi Nov 19, 2025
7df3a21
fix: refactoring
dibahlfi Nov 20, 2025
6e0574d
Merge branch 'main' into users/dibahl/absolute_timeout_fix
dibahlfi Nov 20, 2025
bfa6265
fixing pylint errors
dibahlfi Nov 20, 2025
75ea61d
fix: cleaning up code
dibahlfi Nov 21, 2025
3109ff7
fix: fixing test
dibahlfi Nov 21, 2025
90fc321
resolving conflicts
dibahlfi Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#### Breaking Changes

#### Bugs Fixed
* Fixed bug where client `timeout`/`read_timeout` values were not properly enforced. See [PR 42652](https://github.com/Azure/azure-sdk-for-python/pull/42652)
* Fixed bug where passing in None for some options in `query_items` would cause unexpected errors. See [PR 44098](https://github.com/Azure/azure-sdk-for-python/pull/44098)

#### Other Changes
Expand Down
8 changes: 8 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""

import base64
import time
from email.utils import formatdate
import json
import uuid
Expand Down Expand Up @@ -113,6 +114,13 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]:
for key, value in _COMMON_OPTIONS.items():
if key in kwargs:
options[value] = kwargs.pop(key)
if 'read_timeout' in kwargs:
options['read_timeout'] = kwargs['read_timeout']
if 'timeout' in kwargs:
options['timeout'] = kwargs['timeout']


options[Constants.OperationStartTime] = time.time()
if_match, if_none_match = _get_match_headers(kwargs)
if if_match:
options['accessCondition'] = {'type': 'IfMatch', 'condition': if_match}
Expand Down
9 changes: 8 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@
from typing_extensions import Literal
# cspell:ignore PPAF

# cspell:ignore reranker
class TimeoutScope:
"""Defines the scope of timeout application"""
OPERATION: Literal["operation"] = "operation" # Apply timeout to entire logical operation
PAGE: Literal["page"] = "page" # Apply timeout to individual page requests

# cspell:ignore reranker

class _Constants:
"""Constants used in the azure-cosmos package"""

UserConsistencyPolicy: Literal["userConsistencyPolicy"] = "userConsistencyPolicy"
DefaultConsistencyLevel: Literal["defaultConsistencyLevel"] = "defaultConsistencyLevel"
OperationStartTime: Literal["operationStartTime"] = "operationStartTime"
# whether to apply timeout to the whole logical operation or just a page request
TimeoutScope: Literal["timeoutScope"] = "timeoutScope"

# GlobalDB related constants
WritableLocations: Literal["writableLocations"] = "writableLocations"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,18 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma
"""
if options is None:
options = {}
read_timeout = options.get("read_timeout")
if read_timeout is not None:
# We currently have a gap where kwargs are not passed correctly down the pipeline. To make the
# absolute timeout work, we pass read_timeout via kwargs as a temporary fix.
kwargs.setdefault("read_timeout", read_timeout)

operation_start_time = options.get(Constants.OperationStartTime)
if operation_start_time is not None:
kwargs.setdefault(Constants.OperationStartTime, operation_start_time)
timeout = options.get("timeout")
if timeout is not None:
kwargs.setdefault("timeout", timeout)

if query:
__GetBodiesFromQueryResult = result_fn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,17 @@ async def _fetch_items_helper_no_retries(self, fetch_function):
return fetched_items

async def _fetch_items_helper_with_retries(self, fetch_function):
async def callback():
# TODO: Properly propagate kwargs from retry utility to fetch function
# The callback keeps the **kwargs parameter to stay compatible with the retry utility's execution pattern.
# ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.).
# The callback needs to accept these parameters even if they are unused.
# Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters.
async def callback(**kwargs): # pylint: disable=unused-argument
return await self._fetch_items_helper_no_retries(fetch_function)

return await _retry_utility_async.ExecuteAsync(self._client, self._client._global_endpoint_manager, callback)
return await _retry_utility_async.ExecuteAsync(
self._client, self._client._global_endpoint_manager, callback, **self._options
)


class _DefaultQueryExecutionContext(_QueryExecutionContextBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,19 @@ def __init__(self, client, resource_link, query, options, fetch_function,
async def _create_execution_context_with_query_plan(self):
self._fetched_query_plan = True
query_to_use = self._query if self._query is not None else "Select * from root r"
query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway
(query_to_use, self._resource_link, self._options.get('excludedLocations')))
query_plan = await self._client._GetQueryPlanThroughGateway(
query_to_use,
self._resource_link,
self._options.get('excludedLocations'),
read_timeout=self._options.get('read_timeout')
)
query_execution_info = _PartitionedQueryExecutionInfo(query_plan)
qe_info = getattr(query_execution_info, "_query_execution_info", None)
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
params = query_to_use.get("parameters")
if params is not None:
query_execution_info._query_execution_info['parameters'] = params

self._execution_context = await self._create_pipelined_execution_context(query_execution_info)

async def __anext__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,15 @@ def _fetch_items_helper_no_retries(self, fetch_function):
return fetched_items

def _fetch_items_helper_with_retries(self, fetch_function):
def callback():
# TODO: Properly propagate kwargs from retry utility to fetch function
# The callback keeps the **kwargs parameter to stay compatible with the retry utility's execution pattern.
# Execute passes retry context parameters (timeout, operation start time, logger, etc.).
# The callback needs to accept these parameters even if they are unused.
# Removing **kwargs results in a TypeError when Execute tries to pass these parameters.
def callback(**kwargs): # pylint: disable=unused-argument
return self._fetch_items_helper_no_retries(fetch_function)

return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback)
return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback, **self._options)

next = __next__ # Python 2 compatibility.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,19 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon
def _create_execution_context_with_query_plan(self):
self._fetched_query_plan = True
query_to_use = self._query if self._query is not None else "Select * from root r"
query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway
(query_to_use, self._resource_link, self._options.get('excludedLocations')))

query_plan = self._client._GetQueryPlanThroughGateway(
query_to_use,
self._resource_link,
self._options.get('excludedLocations'),
read_timeout=self._options.get('read_timeout')
)
query_execution_info = _PartitionedQueryExecutionInfo(query_plan)
qe_info = getattr(query_execution_info, "_query_execution_info", None)
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
params = query_to_use.get("parameters")
if params is not None:
query_execution_info._query_execution_info['parameters'] = params

self._execution_context = self._create_pipelined_execution_context(query_execution_info)

def __next__(self):
Expand Down
14 changes: 14 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@

"""Iterable query results in the Azure Cosmos database service.
"""
import time
from azure.core.paging import PageIterator # type: ignore
from azure.cosmos._constants import _Constants, TimeoutScope
from azure.cosmos._execution_context import execution_dispatcher
from azure.cosmos import exceptions

# pylint: disable=protected-access

Expand Down Expand Up @@ -99,6 +102,17 @@ def _fetch_next(self, *args): # pylint: disable=unused-argument
:return: List of results.
:rtype: list
"""
timeout = self._options.get('timeout')
# reset the operation start time if it's a paged request
if timeout and self._options.get(_Constants.TimeoutScope) != TimeoutScope.OPERATION:
self._options[_Constants.OperationStartTime] = time.time()

# Check timeout before fetching next block
if timeout:
elapsed = time.time() - self._options.get(_Constants.OperationStartTime)
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError()

block = self._ex_context.fetch_next_block()
if not block:
raise StopIteration
Expand Down
35 changes: 29 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
from . import _session_retry_policy
from . import _timeout_failover_retry_policy
from . import exceptions
from ._constants import _Constants
from .documents import _OperationType
from .exceptions import CosmosHttpResponseError
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes, ResourceType
from ._cosmos_http_logging_policy import _log_diagnostics_error


# pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches
# cspell:ignore PPAF,ppaf,ppcb

Expand All @@ -65,6 +65,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
:returns: the result of running the passed in function as a (result, headers) tuple
:rtype: tuple of (dict, dict)
"""
# Capture the client timeout and start time at the beginning
timeout = kwargs.get('timeout')
operation_start_time = kwargs.get(_Constants.OperationStartTime, time.time())

# Track the last error for chaining
last_error = None

pk_range_wrapper = None
if args and (global_endpoint_manager.is_per_partition_automatic_failover_applicable(args[0]) or
global_endpoint_manager.is_circuit_breaker_applicable(args[0])):
Expand Down Expand Up @@ -115,14 +122,25 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
client, client._container_properties_cache, None, *args)

while True:
client_timeout = kwargs.get('timeout')
start_time = time.time()
# Check timeout before executing function
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)

try:
if args:
result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs)
global_endpoint_manager.record_success(args[0], pk_range_wrapper)
else:
result = ExecuteFunction(function, *args, **kwargs)
# Check timeout after successful execution
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)

if not client.last_response_headers:
client.last_response_headers = {}

Expand Down Expand Up @@ -163,6 +181,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin

return result
except exceptions.CosmosHttpResponseError as e:
last_error = e
if request:
# update session token for relevant operations
client._UpdateSessionIfRequired(request.headers, {}, e.headers)
Expand Down Expand Up @@ -236,12 +255,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
client.session.clear_session_token(client.last_response_headers)
raise

# Now check timeout before retrying
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)
# Wait for retry_after_in_milliseconds time before the next retry
time.sleep(retry_policy.retry_after_in_milliseconds / 1000.0)
if client_timeout:
kwargs['timeout'] = client_timeout - (time.time() - start_time)
if kwargs['timeout'] <= 0:
raise exceptions.CosmosClientTimeoutError()

except ServiceRequestError as e:
if request and _has_database_account_header(request.headers):
Expand Down Expand Up @@ -270,6 +290,7 @@ def ExecuteFunction(function, *args, **kwargs):
"""
return function(*args, **kwargs)


def _has_read_retryable_headers(request_headers):
if _OperationType.IsReadOnlyOperation(request_headers.get(HttpHeaders.ThinClientProxyOperationType)):
return True
Expand Down Expand Up @@ -345,6 +366,7 @@ def send(self, request):
:raises ~azure.cosmos.exceptions.CosmosClientTimeoutError: Specified timeout exceeded.
:raises ~azure.core.exceptions.ClientAuthenticationError: Authentication failed.
"""

absolute_timeout = request.context.options.pop('timeout', None)
per_request_timeout = request.context.options.pop('connection_timeout', 0)
request_params = request.context.options.pop('request_params', None)
Expand Down Expand Up @@ -397,6 +419,7 @@ def send(self, request):
if retry_active:
self.sleep(retry_settings, request.context.transport)
continue

raise err
except CosmosHttpResponseError as err:
raise err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from urllib.parse import urlparse
from azure.core.exceptions import DecodeError # type: ignore

from ._constants import _Constants
from . import exceptions, http_constants, _retry_utility
from ._utils import get_user_agent_features

Expand Down Expand Up @@ -80,7 +80,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin

"""
# pylint: disable=protected-access, too-many-branches

kwargs.pop(_Constants.OperationStartTime, None)
connection_timeout = connection_policy.RequestTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
read_timeout = connection_policy.ReadTimeout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .. import exceptions
from .. import http_constants
from . import _retry_utility_async
from .._constants import _Constants
from .._synchronized_request import _request_body_from_data, _replace_url_prefix
from .._utils import get_user_agent_features

Expand All @@ -51,7 +52,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p

"""
# pylint: disable=protected-access, too-many-branches

kwargs.pop(_Constants.OperationStartTime, None)
connection_timeout = connection_policy.RequestTimeout
read_timeout = connection_policy.ReadTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
Expand Down
14 changes: 11 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from .._base import (_build_properties_cache, _deserialize_throughput, _replace_throughput,
build_options as _build_options, GenerateGuidId, validate_cache_staleness_value)
from .._change_feed.feed_range_internal import FeedRangeInternalEpk
from .._constants import _Constants as Constants

from .._cosmos_responses import CosmosDict, CosmosList
from .._constants import _Constants as Constants, TimeoutScope
from .._routing.routing_range import Range
from .._session_token_helpers import get_latest_session_token
from ..exceptions import CosmosHttpResponseError
Expand Down Expand Up @@ -96,8 +97,14 @@ def __repr__(self) -> str:

async def _get_properties_with_options(self, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
kwargs = {}
if options and "excludedLocations" in options:
kwargs['excluded_locations'] = options['excludedLocations']
if options:
if "excludedLocations" in options:
kwargs['excluded_locations'] = options['excludedLocations']
if Constants.OperationStartTime in options:
kwargs[Constants.OperationStartTime] = options[Constants.OperationStartTime]
if "timeout" in options:
kwargs['timeout'] = options['timeout']

return await self._get_properties(**kwargs)

async def _get_properties(self, **kwargs: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -484,6 +491,7 @@ async def read_items(
query_options = _build_options(kwargs)
await self._get_properties_with_options(query_options)
query_options["enableCrossPartitionQuery"] = True
query_options[Constants.TimeoutScope] = TimeoutScope.OPERATION

item_tuples = [(item_id, await self._set_partition_key(pk)) for item_id, pk in items]
return await self.client_connection.read_items(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2979,6 +2979,21 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements,
if options is None:
options = {}

read_timeout = options.get("read_timeout")
if read_timeout is not None:
# We currently have a gap where kwargs are not passed correctly down the pipeline. To make the
# absolute timeout work, we pass read_timeout via kwargs as a temporary fix.
kwargs.setdefault("read_timeout", read_timeout)

operation_start_time = options.get(Constants.OperationStartTime)
if operation_start_time is not None:
# The operation start time must be set in kwargs, as that's where it is read when the request is sent.
kwargs.setdefault(Constants.OperationStartTime, operation_start_time)
timeout = options.get("timeout")
if timeout is not None:
# The timeout must be set in kwargs, as that's where it is read when the request is sent.
kwargs.setdefault("timeout", timeout)

if query:
__GetBodiesFromQueryResult = result_fn
else:
Expand Down Expand Up @@ -3383,7 +3398,7 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str,
"contentType": runtime_constants.MediaTypes.Json,
"isQueryPlanRequest": True,
"supportedQueryFeatures": supported_query_features,
"queryVersion": http_constants.Versions.QueryVersion
"queryVersion": http_constants.Versions.QueryVersion,
}
if excluded_locations is not None:
options["excludedLocations"] = excluded_locations
Expand Down
Loading
Loading