Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions airflow/providers/amazon/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@
Changelog
---------

main
....

Bug Fixes
~~~~~~~~~

* Reduce memory footprint of s3 key trigger (#40473)
- Decorator ``provide_bucket_name_async`` removed
* A separate decorator for async is no longer needed. The old one is removed; ``provide_bucket_name``
  can now be used on coroutine functions, async generator functions, and normal synchronous functions.
- Hook method ``get_file_metadata_async`` is now an async iterator
* Previously, the metadata objects were accumulated in a list. Now the objects are yielded as we page
  through the results. To collect them into a list, use an async list comprehension, e.g.
  ``[obj async for obj in hook.get_file_metadata_async(...)]``.
- S3KeyTrigger avoids loading all positive matches into memory in some circumstances


8.25.0
......

Expand Down
114 changes: 62 additions & 52 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import asyncio
import fnmatch
import gzip as gz
import inspect
import logging
import os
import re
Expand All @@ -36,7 +37,7 @@
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
from urllib.parse import urlsplit
from uuid import uuid4

Expand Down Expand Up @@ -65,38 +66,15 @@ def provide_bucket_name(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
self = args[0]

if "bucket_name" in self.service_config:
bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
elif self.conn_config and self.conn_config.schema:
warnings.warn(
"s3 conn_type, and the associated schema field, is deprecated. "
"Please use aws conn_type instead, and specify `bucket_name` "
"in `service_config.s3` within `extras`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
bound_args.arguments["bucket_name"] = self.conn_config.schema

return func(*bound_args.args, **bound_args.kwargs)

return wrapper


def provide_bucket_name_async(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
function_signature = signature(func)
if "bucket_name" not in function_signature.parameters:
raise RuntimeError(
"Decorator provide_bucket_name should only wrap a function with param 'bucket_name'."
)

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# todo: raise immediately if func has no bucket_name arg
async def maybe_add_bucket_name(*args, **kwargs):
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
Expand All @@ -105,8 +83,46 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
if connection.schema:
bound_args.arguments["bucket_name"] = connection.schema
return bound_args

if inspect.iscoroutinefunction(func):

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
bound_args = await maybe_add_bucket_name(*args, **kwargs)
print(f"invoking async function {func=}")
return await func(*bound_args.args, **bound_args.kwargs)

elif inspect.isasyncgenfunction(func):

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
bound_args = await maybe_add_bucket_name(*args, **kwargs)
async for thing in func(*bound_args.args, **bound_args.kwargs):
yield thing

else:

@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
self = args[0]

if "bucket_name" in self.service_config:
bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
elif self.conn_config and self.conn_config.schema:
warnings.warn(
"s3 conn_type, and the associated schema field, is deprecated. "
"Please use aws conn_type instead, and specify `bucket_name` "
"in `service_config.s3` within `extras`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
bound_args.arguments["bucket_name"] = self.conn_config.schema

return await func(*bound_args.args, **bound_args.kwargs)
return func(*bound_args.args, **bound_args.kwargs)

return wrapper

Expand Down Expand Up @@ -400,8 +416,8 @@ def list_prefixes(

return prefixes

@provide_bucket_name_async
@unify_bucket_name_and_key
@provide_bucket_name
async def get_head_object_async(
self, client: AioBaseClient, key: str, bucket_name: str | None = None
) -> dict[str, Any] | None:
Expand Down Expand Up @@ -462,10 +478,10 @@ async def list_prefixes_async(

return prefixes

@provide_bucket_name_async
@provide_bucket_name
async def get_file_metadata_async(
self, client: AioBaseClient, bucket_name: str, key: str | None = None
) -> list[Any]:
) -> AsyncIterator[Any]:
"""
Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.

Expand All @@ -477,11 +493,10 @@ async def get_file_metadata_async(
delimiter = ""
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
files = []
async for page in response:
if "Contents" in page:
files += page["Contents"]
return files
for row in page["Contents"]:
yield row

async def _check_key_async(
self,
Expand All @@ -506,21 +521,16 @@ async def _check_key_async(
"""
bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
if wildcard_match:
keys = await self.get_file_metadata_async(client, bucket_name, key)
key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
if not key_matches:
return False
elif use_regex:
keys = await self.get_file_metadata_async(client, bucket_name)
key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
if not key_matches:
return False
else:
obj = await self.get_head_object_async(client, key, bucket_name)
if obj is None:
return False

return True
async for k in self.get_file_metadata_async(client, bucket_name, key):
if fnmatch.fnmatch(k["Key"], key):
return True
return False
if use_regex:
async for k in self.get_file_metadata_async(client, bucket_name):
if re.match(pattern=key, string=k["Key"]):
return True
return False
return bool(await self.get_head_object_async(client, key, bucket_name))

async def check_key_async(
self,
Expand Down
122 changes: 78 additions & 44 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import re
from datetime import datetime as std_datetime, timezone
from unittest import mock, mock as async_mock
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from urllib.parse import parse_qs

import boto3
Expand Down Expand Up @@ -428,8 +428,9 @@ async def test_s3_key_hook_get_file_metadata_async(self, mock_client):

s3_hook_async = S3Hook(client_type="S3")
mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
task = await s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*")
assert task == [
keys = [x async for x in s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*")]

assert keys == [
{"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)},
{"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)},
]
Expand Down Expand Up @@ -632,64 +633,90 @@ async def test_s3_prefix_sensor_hook_check_for_prefix_async(

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test__check_key_async_without_wildcard_match(
self, mock_client, mock_head_object, mock_get_bucket_key
):
async def test__check_key_async_without_wildcard_match(self, mock_get_conn, mock_get_bucket_key):
"""Test _check_key_async function without using wildcard_match"""
mock_get_bucket_key.return_value = "test_bucket", "test.txt"
mock_head_object.return_value = {"ContentLength": 0}
mock_client = mock_get_conn.return_value
mock_client.head_object = AsyncMock(return_value={"ContentLength": 0})
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt"
mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt"
)
assert response is True

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_async_without_wildcard_match_and_get_none(
self, mock_client, mock_head_object, mock_get_bucket_key
self, mock_get_conn, mock_get_bucket_key
):
"""Test _check_key_async function when get head object returns none"""
mock_get_bucket_key.return_value = "test_bucket", "test.txt"
mock_head_object.return_value = None
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
mock_client = mock_get_conn.return_value
mock_client.head_object = AsyncMock(return_value=None)
response = await s3_hook_async._check_key_async(
mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt"
mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt"
)
assert response is False

# @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_async_with_wildcard_match(
self, mock_client, mock_get_file_metadata, mock_get_bucket_key
):
@pytest.mark.parametrize(
"contents, result",
[
(
[
{
"Key": "test/example_s3_test_file.txt",
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
{
"Key": "test_key2",
"ETag": "etag2",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
],
True,
),
(
[
{
"Key": "test/example_aeoua.txt",
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
{
"Key": "test_key2",
"ETag": "etag2",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
],
False,
),
],
)
async def test_s3__check_key_async_with_wildcard_match(self, mock_get_conn, contents, result):
"""Test _check_key_async function"""
mock_get_bucket_key.return_value = "test_bucket", "test"
mock_get_file_metadata.return_value = [
{
"Key": "test_key",
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
{
"Key": "test_key2",
"ETag": "etag2",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
]
client = mock_get_conn.return_value
paginator = client.get_paginator.return_value
r = paginator.paginate.return_value
r.__aiter__.return_value = [{"Contents": contents}]
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
mock_client.return_value, "test_bucket", True, "test/example_s3_test_file.txt"
client=client,
bucket_val="test_bucket",
wildcard_match=True,
key="test/example_s3_test_file.txt",
)
assert response is False
assert response is result

@pytest.mark.parametrize(
"key, pattern, expected",
Expand All @@ -701,24 +728,31 @@ async def test_s3__check_key_async_with_wildcard_match(
)
@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test__check_key_async_with_use_regex(
self, mock_client, mock_get_file_metadata, mock_get_bucket_key, key, pattern, expected
self, mock_get_conn, mock_get_bucket_key, key, pattern, expected
):
"""Match AWS S3 key with regex expression"""
mock_get_bucket_key.return_value = "test_bucket", pattern
mock_get_file_metadata.return_value = [
client = mock_get_conn.return_value
paginator = client.get_paginator.return_value
r = paginator.paginate.return_value
r.__aiter__.return_value = [
{
"Key": key,
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
"Contents": [
{
"Key": key,
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
]
}
]

s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
client=mock_client.return_value,
client=client,
bucket_val="test_bucket",
wildcard_match=False,
key=pattern,
Expand Down