From cee54d92672c855adb6764033cbfb5dd073cc3fb Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:55:22 -0700 Subject: [PATCH 1/8] Use generator to reduce memory footprint We can return True on the first positive. And we don't need to keep track of the files. --- airflow/providers/amazon/aws/hooks/s3.py | 35 +++++++++++------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index a65c79a726ead..1afbe0fbc40bd 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -36,7 +36,7 @@ from io import BytesIO from pathlib import Path from tempfile import NamedTemporaryFile, gettempdir -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable from urllib.parse import urlsplit from uuid import uuid4 @@ -465,7 +465,7 @@ async def list_prefixes_async( @provide_bucket_name_async async def get_file_metadata_async( self, client: AioBaseClient, bucket_name: str, key: str | None = None - ) -> list[Any]: + ) -> AsyncIterator[Any]: """ Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously. @@ -477,11 +477,10 @@ async def get_file_metadata_async( delimiter = "" paginator = client.get_paginator("list_objects_v2") response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter) - files = [] async for page in response: if "Contents" in page: - files += page["Contents"] - return files + for row in page["Contents"]: + yield row async def _check_key_async( self, @@ -506,21 +505,19 @@ async def _check_key_async( """ bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key") if wildcard_match: - keys = await self.get_file_metadata_async(client, bucket_name, key) - key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)] - if not key_matches: - return False - elif use_regex: - keys = await self.get_file_metadata_async(client, bucket_name) - key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])] - if not key_matches: - return False + async for k in self.get_file_metadata_async(client, bucket_name, key): + if fnmatch.fnmatch(k["Key"], key): + return True + return False + if use_regex: + async for k in self.get_file_metadata_async(client, bucket_name): + if re.match(pattern=key, string=k["Key"]): + return True + return False + if await self.get_head_object_async(client, key, bucket_name): + return True else: - obj = await self.get_head_object_async(client, key, bucket_name) - if obj is None: - return False - - return True + return False async def check_key_async( self, From 14c9111796cccbdb09cf434d5329924d309717fd Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:32:58 -0700 Subject: [PATCH 2/8] add tests --- airflow/providers/amazon/aws/hooks/s3.py | 77 +++++++----- tests/providers/amazon/aws/hooks/test_s3.py | 122 +++++++++++++------- 2 files changed, 123 insertions(+), 76 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 1afbe0fbc40bd..0738f32957653 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -22,6 +22,7 @@ import asyncio import fnmatch import gzip as gz +import inspect import logging import os import re @@ -62,41 +63,15 @@ def provide_bucket_name(func: Callable) -> Callable: + """Provide a bucket name taken from the connection if no bucket name has been passed to the function.""" """Provide a bucket name taken from the connection if no bucket name has been passed to the function.""" if hasattr(func, "_unify_bucket_name_and_key_wrapped"): logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.") - function_signature = signature(func) - - @wraps(func) - def wrapper(*args, **kwargs) -> Callable: - bound_args = function_signature.bind(*args, **kwargs) - - if "bucket_name" not in bound_args.arguments: - self = args[0] - - if "bucket_name" in self.service_config: - bound_args.arguments["bucket_name"] = self.service_config["bucket_name"] - elif self.conn_config and self.conn_config.schema: - warnings.warn( - "s3 conn_type, and the associated schema field, is deprecated. " - "Please use aws conn_type instead, and specify `bucket_name` " - "in `service_config.s3` within `extras`.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - bound_args.arguments["bucket_name"] = self.conn_config.schema - return func(*bound_args.args, **bound_args.kwargs) - - return wrapper - - -def provide_bucket_name_async(func: Callable) -> Callable: - """Provide a bucket name taken from the connection if no bucket name has been passed to the function.""" function_signature = signature(func) - @wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Any: + # todo: raise immediately if func has no bucket_name arg + async def maybe_add_bucket_name(*args, **kwargs): bound_args = function_signature.bind(*args, **kwargs) if "bucket_name" not in bound_args.arguments: @@ -105,8 +80,46 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: connection = await sync_to_async(self.get_connection)(self.aws_conn_id) if connection.schema: bound_args.arguments["bucket_name"] = connection.schema + return bound_args + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + bound_args = await maybe_add_bucket_name(*args, **kwargs) + print(f"invoking async function {func=}") + return await func(*bound_args.args, **bound_args.kwargs) + + elif inspect.isasyncgenfunction(func): + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + bound_args = await maybe_add_bucket_name(*args, **kwargs) + async for thing in func(*bound_args.args, **bound_args.kwargs): + yield thing + + else: + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + bound_args = function_signature.bind(*args, **kwargs) + + if "bucket_name" not in bound_args.arguments: + self = args[0] + + if "bucket_name" in self.service_config: + bound_args.arguments["bucket_name"] = self.service_config["bucket_name"] + elif self.conn_config and self.conn_config.schema: + warnings.warn( + "s3 conn_type, and the associated schema field, is deprecated. " + "Please use aws conn_type instead, and specify `bucket_name` " + "in `service_config.s3` within `extras`.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + bound_args.arguments["bucket_name"] = self.conn_config.schema - return await func(*bound_args.args, **bound_args.kwargs) + return func(*bound_args.args, **bound_args.kwargs) return wrapper @@ -400,8 +413,8 @@ def list_prefixes( return prefixes - @provide_bucket_name_async @unify_bucket_name_and_key + @provide_bucket_name async def get_head_object_async( self, client: AioBaseClient, key: str, bucket_name: str | None = None ) -> dict[str, Any] | None: @@ -462,7 +475,7 @@ async def list_prefixes_async( return prefixes - @provide_bucket_name_async + @provide_bucket_name async def get_file_metadata_async( self, client: AioBaseClient, bucket_name: str, key: str | None = None ) -> AsyncIterator[Any]: diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index eb8a883f64c15..2563af5ba70d2 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -23,7 +23,7 @@ import re from datetime import datetime as std_datetime, timezone from unittest import mock, mock as async_mock -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import parse_qs import boto3 @@ -428,8 +428,9 @@ async def test_s3_key_hook_get_file_metadata_async(self, mock_client): s3_hook_async = S3Hook(client_type="S3") mock_client.get_paginator = mock.Mock(return_value=mock_paginator) - task = await s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*") - assert task == [ + keys = [x async for x in s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*")] + + assert keys == [ {"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, {"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, ] @@ -632,64 +633,90 @@ async def test_s3_prefix_sensor_hook_check_for_prefix_async( @pytest.mark.asyncio @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") - @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async") @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") - async def test__check_key_async_without_wildcard_match( - self, mock_client, mock_head_object, mock_get_bucket_key - ): + async def test__check_key_async_without_wildcard_match(self, mock_get_conn, mock_get_bucket_key): """Test _check_key_async function without using wildcard_match""" mock_get_bucket_key.return_value = "test_bucket", "test.txt" - mock_head_object.return_value = {"ContentLength": 0} + mock_client = mock_get_conn.return_value + mock_client.head_object = AsyncMock(return_value={"ContentLength": 0}) s3_hook_async = S3Hook(client_type="S3", resource_type="S3") response = await s3_hook_async._check_key_async( - mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt" + mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt" ) assert response is True @pytest.mark.asyncio @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") - @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async") @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") async def test_s3__check_key_async_without_wildcard_match_and_get_none( - self, mock_client, mock_head_object, mock_get_bucket_key + self, mock_get_conn, mock_get_bucket_key ): """Test _check_key_async function when get head object returns none""" mock_get_bucket_key.return_value = "test_bucket", "test.txt" - mock_head_object.return_value = None s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + mock_client = mock_get_conn.return_value + mock_client.head_object = AsyncMock(return_value=None) response = await s3_hook_async._check_key_async( - mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt" + mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt" ) assert response is False + # @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") - @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async") @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") - async def test_s3__check_key_async_with_wildcard_match( - self, mock_client, mock_get_file_metadata, mock_get_bucket_key - ): + @pytest.mark.parametrize( + "contents, result", + [ + ( + [ + { + "Key": "test/example_s3_test_file.txt", + "ETag": "etag1", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + { + "Key": "test_key2", + "ETag": "etag2", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + ], + True, + ), + ( + [ + { + "Key": "test/example_aeoua.txt", + "ETag": "etag1", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + { + "Key": "test_key2", + "ETag": "etag2", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + ], + False, + ), + ], + ) + async def test_s3__check_key_async_with_wildcard_match(self, mock_get_conn, contents, result): """Test _check_key_async function""" - mock_get_bucket_key.return_value = "test_bucket", "test" - mock_get_file_metadata.return_value = [ - { - "Key": "test_key", - "ETag": "etag1", - "LastModified": datetime(2020, 8, 14, 17, 19, 34), - "Size": 0, - }, - { - "Key": "test_key2", - "ETag": "etag2", - "LastModified": datetime(2020, 8, 14, 17, 19, 34), - "Size": 0, - }, - ] + client = mock_get_conn.return_value + paginator = client.get_paginator.return_value + r = paginator.paginate.return_value + r.__aiter__.return_value = [{"Contents": contents}] s3_hook_async = S3Hook(client_type="S3", resource_type="S3") response = await s3_hook_async._check_key_async( - mock_client.return_value, "test_bucket", True, "test/example_s3_test_file.txt" + client=client, + bucket_val="test_bucket", + wildcard_match=True, + key="test/example_s3_test_file.txt", ) - assert response is False + assert response is result @pytest.mark.parametrize( "key, pattern, expected", @@ -701,24 +728,31 @@ async def test_s3__check_key_async_with_wildcard_match( ) @pytest.mark.asyncio @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") - @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async") @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") async def test__check_key_async_with_use_regex( - self, mock_client, mock_get_file_metadata, mock_get_bucket_key, key, pattern, expected + self, mock_get_conn, mock_get_bucket_key, key, pattern, expected ): """Match AWS S3 key with regex expression""" mock_get_bucket_key.return_value = "test_bucket", pattern - mock_get_file_metadata.return_value = [ + client = mock_get_conn.return_value + paginator = client.get_paginator.return_value + r = paginator.paginate.return_value + r.__aiter__.return_value = [ { - "Key": key, - "ETag": "etag1", - "LastModified": datetime(2020, 8, 14, 17, 19, 34), - "Size": 0, - }, + "Contents": [ + { + "Key": key, + "ETag": "etag1", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + ] + } ] + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") response = await s3_hook_async._check_key_async( - client=mock_client.return_value, + client=client, bucket_val="test_bucket", wildcard_match=False, key=pattern, From 5e7298a36a784a24771b18c7354b15a6e0b41de9 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:36:11 -0700 Subject: [PATCH 3/8] fixup --- airflow/providers/amazon/aws/hooks/s3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 0738f32957653..d0402a42945ef 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -63,7 +63,6 @@ def provide_bucket_name(func: Callable) -> Callable: - """Provide a bucket name taken from the connection if no bucket name has been passed to the function.""" """Provide a bucket name taken from the connection if no bucket name has been passed to the function.""" if hasattr(func, "_unify_bucket_name_and_key_wrapped"): logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.") From cb2f31aaf0ab67d4718b2addb414ee9c33c839bc Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 28 Jun 2024 09:50:37 -0700 Subject: [PATCH 4/8] Update airflow/providers/amazon/aws/hooks/s3.py Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com> --- airflow/providers/amazon/aws/hooks/s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index d0402a42945ef..1e891803f2929 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -526,7 +526,7 @@ async def _check_key_async( if re.match(pattern=key, string=k["Key"]): return True return False - if await self.get_head_object_async(client, key, bucket_name): + return bool(await self.get_head_object_async(client, key, bucket_name)) return True else: return False From 6b4a405a4216e61f3986ebcf00d5035c14a1c17f Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 28 Jun 2024 09:51:32 -0700 Subject: [PATCH 5/8] Revert "Update airflow/providers/amazon/aws/hooks/s3.py" This reverts commit cb2f31aaf0ab67d4718b2addb414ee9c33c839bc. --- airflow/providers/amazon/aws/hooks/s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 1e891803f2929..d0402a42945ef 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -526,7 +526,7 @@ async def _check_key_async( if re.match(pattern=key, string=k["Key"]): return True return False - return bool(await self.get_head_object_async(client, key, bucket_name)) + if await self.get_head_object_async(client, key, bucket_name): return True else: return False From a92e3b45cd568d85e725baf4c4c3f02ef6a095d0 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 28 Jun 2024 09:53:49 -0700 Subject: [PATCH 6/8] reapply vincent's suggestion --- airflow/providers/amazon/aws/hooks/s3.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index d0402a42945ef..40940f0ec756a 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -526,10 +526,7 @@ async def _check_key_async( if re.match(pattern=key, string=k["Key"]): return True return False - if await self.get_head_object_async(client, key, bucket_name): - return True - else: - return False + return bool(await self.get_head_object_async(client, key, bucket_name)) async def check_key_async( self, From 42035cedeaa1ce66ef7cf4ac58734401f305a654 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:32:48 -0700 Subject: [PATCH 7/8] add check for param --- airflow/providers/amazon/aws/hooks/s3.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 40940f0ec756a..cfd4b833cf9c6 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -68,6 +68,10 @@ def provide_bucket_name(func: Callable) -> Callable: logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.") function_signature = signature(func) + if "bucket_name" not in function_signature.parameters: + raise RuntimeError( + "Decorator provide_bucket_name should only wrap a function with param 'bucket_name'." + ) # todo: raise immediately if func has no bucket_name arg async def maybe_add_bucket_name(*args, **kwargs): From 38d4350f6677f554686cdb2814b12ba9ccc57408 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:51:21 -0700 Subject: [PATCH 8/8] add changelog --- airflow/providers/amazon/CHANGELOG.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst index 38ee51565068d..51e1851e012ba 100644 --- a/airflow/providers/amazon/CHANGELOG.rst +++ b/airflow/providers/amazon/CHANGELOG.rst @@ -26,6 +26,22 @@ Changelog --------- +main +.... + +Bug Fixes +~~~~~~~~~ + +* Reduce memory footprint of s3 key trigger (#40473) + - Decorator ``provide_bucket_name_async`` removed + * We do not need a separate decorator for async. The old one is removed and users can use ``provide_bucket_name`` + for coroutine functions, async iterators, and normal synchronous functions. + - Hook method ``get_file_metadata_async`` is now an async iterator + * Previously, the metadata objects were accumulated in a list. Now the objects are yielded as we page + through the results. To get a list you may use ``async for`` in a list comprehension. + - S3KeyTrigger avoids loading all positive matches into memory in some circumstances + + 8.25.0 ......