# Patch summary (explanatory preamble placed ahead of the first "diff --git"
# header, so it is not part of any hunk).
#
# NOTE(review): this patch was received with line breaks and indentation
# collapsed. The layout below is reconstructed from the hunk line counts and
# Python syntax; the exact position of blank context lines and the indentation
# of context lines should be confirmed against the target tree before applying.
#
# * b2sdk/_internal/file_version.py: adds the private `_should_be_decoded`
#   property, true when `content_encoding` is set and
#   `api.api_config.decode_content` is enabled.
# * b2sdk/_internal/transfer/inbound/downloaded_file.py: `_validate_download`
#   now compares `bytes_read` with the expected length first, for both ranged
#   and full downloads (previously the whole validation was skipped for decoded
#   content). The SHA-1 comparison still runs only for full, non-decoded
#   downloads with `check_hash` set and a stored hash other than 'none'.
# * b2sdk/_internal/transfer/inbound/downloader/abstract.py: `is_suitable`
#   uses `_should_be_decoded` instead of repeating the two-part condition.
# * b2sdk/_internal/transfer/inbound/downloader/simple.py: `_download` catches
#   ChunkedEncodingError / ConnectionError / ContentDecodingError while
#   streaming. On the initial response the error is re-raised when the content
#   is being decoded; otherwise it is logged at debug level and the ranged
#   re-download loop (at most 4 follow-up requests, skipped for decoded
#   content) continues from the bytes already read.
# * b2sdk/v1/api.py: `__init__` stores `api_config` on the instance.
# * Adds a changelog entry and two new unit-test modules.

diff --git a/b2sdk/_internal/file_version.py b/b2sdk/_internal/file_version.py
index ae0ccf699..655d9df69 100644
--- a/b2sdk/_internal/file_version.py
+++ b/b2sdk/_internal/file_version.py
@@ -471,6 +471,10 @@ def expires_parsed(self) -> dt.datetime | None:
             return None
         return parse_http_date(self.expires)
 
+    @property
+    def _should_be_decoded(self) -> bool:
+        return bool(self.content_encoding and self.api.api_config.decode_content)
+
     def as_dict(self) -> dict:
         result = super().as_dict()
         if self.cache_control is not None:
diff --git a/b2sdk/_internal/transfer/inbound/downloaded_file.py b/b2sdk/_internal/transfer/inbound/downloaded_file.py
index 7242b19df..3abf8dd56 100644
--- a/b2sdk/_internal/transfer/inbound/downloaded_file.py
+++ b/b2sdk/_internal/transfer/inbound/downloaded_file.py
@@ -170,29 +170,22 @@ def __init__(
         self.check_hash = check_hash
 
     def _validate_download(self, bytes_read, actual_sha1):
+        desired_length = self.range_[1] - self.range_[0] + 1 if self.range_ is not None else self.download_version.content_length
+        if bytes_read != desired_length:
+            raise TruncatedOutput(bytes_read, desired_length)
+
         if (
-            self.download_version.content_encoding is not None
-            and self.download_version.api.api_config.decode_content
+            not self.download_version._should_be_decoded
+            and self.check_hash
+            and self.range_ is None
+            and self.download_version.content_sha1 != 'none'
+            and actual_sha1 != self.download_version.content_sha1
         ):
-            return
-        if self.range_ is None:
-            if bytes_read != self.download_version.content_length:
-                raise TruncatedOutput(bytes_read, self.download_version.content_length)
-
-            if (
-                self.check_hash
-                and self.download_version.content_sha1 != 'none'
-                and actual_sha1 != self.download_version.content_sha1
-            ):
-                raise ChecksumMismatch(
-                    checksum_type='sha1',
-                    expected=self.download_version.content_sha1,
-                    actual=actual_sha1,
-                )
-        else:
-            desired_length = self.range_[1] - self.range_[0] + 1
-            if bytes_read != desired_length:
-                raise TruncatedOutput(bytes_read, desired_length)
+            raise ChecksumMismatch(
+                checksum_type='sha1',
+                expected=self.download_version.content_sha1,
+                actual=actual_sha1,
+            )
 
     def save(self, file: BinaryIO, allow_seeking: bool | None = None) -> None:
         """
diff --git a/b2sdk/_internal/transfer/inbound/downloader/abstract.py b/b2sdk/_internal/transfer/inbound/downloader/abstract.py
index 86c71d22d..360026c4e 100644
--- a/b2sdk/_internal/transfer/inbound/downloader/abstract.py
+++ b/b2sdk/_internal/transfer/inbound/downloader/abstract.py
@@ -118,8 +118,7 @@ def is_suitable(self, download_version: DownloadVersion, allow_seeking: bool):
             return False
         if (
             not self.SUPPORTS_DECODE_CONTENT
-            and download_version.content_encoding
-            and download_version.api.api_config.decode_content
+            and download_version._should_be_decoded
         ):
             return False
         return True
diff --git a/b2sdk/_internal/transfer/inbound/downloader/simple.py b/b2sdk/_internal/transfer/inbound/downloader/simple.py
index 87fabac05..2452bb1ad 100644
--- a/b2sdk/_internal/transfer/inbound/downloader/simple.py
+++ b/b2sdk/_internal/transfer/inbound/downloader/simple.py
@@ -12,6 +12,7 @@
 import logging
 from io import IOBase
 
+from requests.exceptions import ChunkedEncodingError, ConnectionError, ContentDecodingError
 from requests.models import Response
 
 from b2sdk._internal.encryption.setting import EncryptionSetting
@@ -41,12 +42,18 @@ def _download(
             response.close()
             return 0, digest.hexdigest()
         chunk_size = self._get_chunk_size(actual_size)
+        should_be_decoded = download_version._should_be_decoded
 
         decoded_bytes_read = 0
-        for data in response.iter_content(chunk_size=chunk_size):
-            file.write(data)
-            digest.update(data)
-            decoded_bytes_read += len(data)
+        try:
+            for data in response.iter_content(chunk_size=chunk_size):
+                file.write(data)
+                digest.update(data)
+                decoded_bytes_read += len(data)
+        except (ChunkedEncodingError, ConnectionError, ContentDecodingError) as exc:
+            if should_be_decoded:
+                raise  # cannot resume a partially decoded stream
+            logger.debug('Stream read error during download, will retry if needed: %s', exc)
         bytes_read = response.raw.tell()
 
         response.close()
@@ -58,9 +65,10 @@ def _download(
         # or something and the server closes connection, while neither tcp or http have a problem
         # with the truncated output, so we detect it here and try to continue
 
-        num_tries = 5  # this is hardcoded because we are going to replace the entire retry interface soon, so we'll avoid deprecation here and keep it private
-        retries_left = num_tries - 1
-        while retries_left and bytes_read < download_version.content_length:
+        retries_left = 4  # this is hardcoded because we are going to replace the entire retry interface soon, so we'll avoid deprecation here and keep it private
+        while (
+            bytes_read < download_version.content_length and not should_be_decoded and retries_left
+        ):
             new_range = self._get_remote_range(
                 response,
                 download_version,
@@ -79,12 +87,15 @@ def _download(
                 new_range.as_tuple(),
                 encryption=encryption,
             ) as followup_response:
-                for data in followup_response.iter_content(
-                    chunk_size=self._get_chunk_size(actual_size)
-                ):
-                    file.write(data)
-                    digest.update(data)
-                    decoded_bytes_read += len(data)
+                try:
+                    for data in followup_response.iter_content(
+                        chunk_size=self._get_chunk_size(actual_size)
+                    ):
+                        file.write(data)
+                        digest.update(data)
+                        decoded_bytes_read += len(data)
+                except (ChunkedEncodingError, ConnectionError, ContentDecodingError) as exc:
+                    logger.debug('Stream read error during download, will retry if needed: %s', exc)
                 bytes_read += followup_response.raw.tell()
             retries_left -= 1
         return bytes_read, digest.hexdigest()
diff --git a/b2sdk/v1/api.py b/b2sdk/v1/api.py
index fafbe7d88..64d820f9e 100644
--- a/b2sdk/v1/api.py
+++ b/b2sdk/v1/api.py
@@ -68,6 +68,7 @@ def __init__(
             raw_api=raw_api,
             api_config=api_config,
         )
+        self.api_config = api_config
         self.file_version_factory = self.FILE_VERSION_FACTORY_CLASS(self)
         self.download_version_factory = self.DOWNLOAD_VERSION_FACTORY_CLASS(self)
         self.services = Services(
diff --git a/changelog.d/+read-error-retry.fixed.md b/changelog.d/+read-error-retry.fixed.md
new file mode 100644
index 000000000..241a89682
--- /dev/null
+++ b/changelog.d/+read-error-retry.fixed.md
@@ -0,0 +1 @@
+Retry stream read errors during download in `SimpleDownloader`.
diff --git a/test/unit/internal/transfer/downloader/test_simple.py b/test/unit/internal/transfer/downloader/test_simple.py
new file mode 100644
index 000000000..b510aefb2
--- /dev/null
+++ b/test/unit/internal/transfer/downloader/test_simple.py
@@ -0,0 +1,188 @@
+######################################################################
+#
+# File: test/unit/internal/transfer/downloader/test_simple.py
+#
+# Copyright 2026 Backblaze Inc. All Rights Reserved.
+#
+# License https://www.backblaze.com/using_b2_code.html
+#
+######################################################################
+import os
+from collections.abc import Callable, Iterator
+from io import BytesIO
+from itertools import count
+from types import ModuleType
+from typing import Any
+
+import pytest
+from apiver_deps import B2Api, Bucket, DownloadVersion, SimpleDownloader
+from requests.exceptions import ChunkedEncodingError, ConnectionError, ContentDecodingError
+from requests.models import Response
+from urllib3.exceptions import DecodeError, IncompleteRead, ProtocolError, ReadTimeoutError
+
+CHUNKED_ENCODING_ERROR = ChunkedEncodingError(
+    ProtocolError(
+        'Connection broken: IncompleteRead(1 bytes read, 99 more expected)',
+        IncompleteRead(1, 99),
+    )
+)
+CONTENT_DECODING_ERROR = ContentDecodingError(
+    DecodeError('Error -3 while decompressing data: incorrect header check')
+)
+CONNECTION_ERROR = ConnectionError(ReadTimeoutError(None, None, 'Read timed out.'))
+
+
+@pytest.fixture
+def file_size() -> int:
+    return 100
+
+
+@pytest.fixture
+def file_content(file_size: int) -> bytes:
+    return os.urandom(file_size)
+
+
+@pytest.fixture
+def mock_download_response(
+    apiver_module: ModuleType,
+    bucket: Bucket,
+    file_content: bytes,
+) -> tuple[Response, DownloadVersion]:
+    file_version = bucket.upload_bytes(file_content, f'dummy_file_{len(file_content)}.txt')
+
+    url = bucket.api.session.get_download_url_by_name(bucket.name, file_version.file_name)
+    response = bucket.api.services.session.download_file_from_url(url).__enter__()
+
+    return (
+        response,
+        apiver_module.DownloadVersionFactory(bucket.api).from_response_headers(response.headers),
+    )
+
+
+@pytest.fixture
+def output_file() -> BytesIO:
+    return BytesIO()
+
+
+@pytest.fixture
+def downloader(apiver_module: ModuleType) -> SimpleDownloader:
+    return apiver_module.SimpleDownloader(force_chunk_size=5)
+
+
+def _make_iter_content(
+    response: Response,
+    attempts: Iterator[int],
+    fail_count: int,
+    stream_error: ChunkedEncodingError | ConnectionError | ContentDecodingError,
+) -> Callable[..., Iterator[bytes]]:
+    def iter_content(chunk_size: int = 1, decode_unicode: bool = False) -> Iterator[bytes]:
+        attempt = next(attempts)
+        chunk = response.raw.read(1)
+        if chunk:
+            yield chunk
+        if attempt <= fail_count:
+            raise stream_error
+        while True:
+            chunk = response.raw.read(chunk_size)
+            if not chunk:
+                break
+            yield chunk
+
+    return iter_content
+
+
+@pytest.mark.parametrize('fail_count', [0, 1, 2, 4, 5])
+@pytest.mark.parametrize(
+    'stream_error',
+    [
+        pytest.param(CHUNKED_ENCODING_ERROR, id='ChunkedEncodingError'),
+        pytest.param(CONNECTION_ERROR, id='ConnectionError'),
+        pytest.param(CONTENT_DECODING_ERROR, id='ContentDecodingError'),
+    ],
+)
+def test_download_file__stream_read_error(
+    b2api: B2Api,
+    bucket: Bucket,
+    downloader: SimpleDownloader,
+    output_file: BytesIO,
+    file_size: int,
+    file_content: bytes,
+    mock_download_response: tuple[Response, DownloadVersion],
+    fail_count: int,
+    stream_error: ChunkedEncodingError | ConnectionError | ContentDecodingError,
+) -> None:
+    mock_response, download_version = mock_download_response
+
+    attempts = count(1)
+    mock_response.iter_content = _make_iter_content(
+        mock_response, attempts, fail_count, stream_error
+    )
+
+    download_func = bucket.api.services.session.download_file_from_url
+
+    def download_func_mock(*args: Any, **kwargs: Any) -> Response:
+        response = download_func(*args, **kwargs).__enter__()
+        response.iter_content = _make_iter_content(response, attempts, fail_count, stream_error)
+        return response
+
+    bucket.api.services.session.download_file_from_url = download_func_mock
+
+    bytes_written, _ = downloader.download(
+        output_file, mock_response, download_version, b2api.session
+    )
+
+    if fail_count < 5:
+        assert bytes_written == file_size
+        assert output_file.getvalue() == file_content
+    else:
+        assert bytes_written == fail_count
+
+
+@pytest.mark.parametrize(
+    'stream_error',
+    [
+        pytest.param(CHUNKED_ENCODING_ERROR, id='ChunkedEncodingError'),
+        pytest.param(CONNECTION_ERROR, id='ConnectionError'),
+        pytest.param(CONTENT_DECODING_ERROR, id='ContentDecodingError'),
+    ],
+)
+def test_download_file__decoded_stream_stream_read_error_reraises(
+    b2api: B2Api,
+    bucket: Bucket,
+    downloader: SimpleDownloader,
+    output_file: BytesIO,
+    file_content: bytes,
+    mock_download_response: tuple[Response, DownloadVersion],
+    stream_error: ChunkedEncodingError | ConnectionError | ContentDecodingError,
+) -> None:
+    """
+    Test that a stream read error during a decoded stream download is re-raised and not retried
+    """
+
+    mock_response, download_version = mock_download_response
+    download_version.content_encoding = 'gzip'
+    download_version.api.api_config.decode_content = True
+
+    attempts = count(1)
+    mock_response.iter_content = _make_iter_content(
+        mock_response, attempts, 1, stream_error
+    )
+
+    followup_calls = 0
+    download_func = bucket.api.services.session.download_file_from_url
+
+    def download_func_mock(*args: Any, **kwargs: Any) -> Response:
+        nonlocal followup_calls
+        followup_calls += 1
+        response = download_func(*args, **kwargs).__enter__()
+        response.iter_content = _make_iter_content(response, attempts, 1, stream_error)
+        return response
+
+    bucket.api.services.session.download_file_from_url = download_func_mock
+
+    with pytest.raises(type(stream_error)):
+        downloader.download(
+            output_file, mock_response, download_version, b2api.session
+        )
+
+    assert followup_calls == 0
diff --git a/test/unit/internal/transfer/test_downloaded_file.py b/test/unit/internal/transfer/test_downloaded_file.py
new file mode 100644
index 000000000..88f4277ba
--- /dev/null
+++ b/test/unit/internal/transfer/test_downloaded_file.py
@@ -0,0 +1,65 @@
+######################################################################
+#
+# File: test/unit/internal/transfer/test_downloaded_file.py
+#
+# Copyright 2026 Backblaze Inc. All Rights Reserved.
+#
+# License https://www.backblaze.com/using_b2_code.html
+#
+######################################################################
+from unittest.mock import Mock
+
+import pytest
+from apiver_deps import DownloadedFile
+from apiver_deps_exception import ChecksumMismatch, TruncatedOutput
+
+
+def _generate_downloaded_file(
+    *,
+    decode_content: bool,
+    content_length: int = 100,
+    content_sha1: str = 'abc',
+    range_: tuple[int, int] | None = None,
+    check_hash: bool = True,
+):
+    download_version = Mock()
+    download_version.content_encoding = 'gzip' if decode_content else None
+    download_version.content_length = content_length
+    download_version.content_sha1 = content_sha1
+    download_version.api.api_config.decode_content = decode_content
+    download_version._should_be_decoded = decode_content
+    return DownloadedFile(
+        download_version=download_version,
+        download_manager=Mock(),
+        range_=range_,
+        response=Mock(),
+        encryption=None,
+        progress_listener=Mock(),
+        check_hash=check_hash,
+    )
+
+
+@pytest.mark.parametrize('decode_content', [True, False])
+def test_validate_download_truncated_full_download(decode_content):
+    # range not set, length doesn't match
+    downloaded_file = _generate_downloaded_file(decode_content=decode_content)
+    with pytest.raises(TruncatedOutput):
+        downloaded_file._validate_download(99, 'abc')
+
+
+@pytest.mark.parametrize('decode_content', [True, False])
+def test_validate_download_truncated_range_download(decode_content):
+    # range set, length doesn't match
+    downloaded_file = _generate_downloaded_file(decode_content=decode_content, range_=(10, 19))
+    with pytest.raises(TruncatedOutput):
+        downloaded_file._validate_download(9, 'abc')
+
+
+@pytest.mark.parametrize('decode_content', [True, False])
+def test_validate_download_hash_check(decode_content):
+    downloaded_file = _generate_downloaded_file(decode_content=decode_content, check_hash=True)
+    if decode_content:
+        downloaded_file._validate_download(100, 'wrong')
+    else:
+        with pytest.raises(ChecksumMismatch):
+            downloaded_file._validate_download(100, 'wrong')