Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ class S3FileTransformOperator(BaseOperator):
:param dest_s3_key: The key to be written from S3. (templated)
:param transform_script: location of the executable transformation script
:param select_expression: S3 Select expression
:param select_expr_serialization_config: A dictionary that contains input and output serialization configurations for S3 Select.
:param script_args: arguments for transformation script (templated)
:param source_aws_conn_id: source s3 connection
:param source_verify: Whether or not to verify SSL certificates for S3 connection.
Expand Down Expand Up @@ -641,6 +642,7 @@ def __init__(
dest_s3_key: str,
transform_script: str | None = None,
select_expression=None,
select_expr_serialization_config: dict[str, dict[str, dict]] | None = None,
script_args: Sequence[str] | None = None,
source_aws_conn_id: str | None = "aws_default",
source_verify: bool | str | None = None,
Expand All @@ -659,6 +661,7 @@ def __init__(
self.replace = replace
self.transform_script = transform_script
self.select_expression = select_expression
self.select_expr_serialization_config = select_expr_serialization_config or {}
self.script_args = script_args or []
self.output_encoding = sys.getdefaultencoding()

Expand All @@ -678,7 +681,14 @@ def execute(self, context: Context):
self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name)

if self.select_expression is not None:
content = source_s3.select_key(key=self.source_s3_key, expression=self.select_expression)
input_serialization = self.select_expr_serialization_config.get("input_serialization")
output_serialization = self.select_expr_serialization_config.get("output_serialization")
content = source_s3.select_key(
key=self.source_s3_key,
expression=self.select_expression,
input_serialization=input_serialization,
output_serialization=output_serialization,
)
f_source.write(content.encode("utf-8"))
else:
source_s3_key_object.download_fileobj(Fileobj=f_source)
Expand Down
42 changes: 41 additions & 1 deletion tests/providers/amazon/aws/operators/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def test_execute_with_transform_script_args(self, mock_popen):
def test_execute_with_select_expression(self, mock_select_key):
input_path, output_path = self.s3_paths()
select_expression = "SELECT * FROM s3object s"
input_serialization = None
output_serialization = None

op = S3FileTransformOperator(
source_s3_key=input_path,
Expand All @@ -294,7 +296,45 @@ def test_execute_with_select_expression(self, mock_select_key):
)
op.execute(None)

mock_select_key.assert_called_once_with(key=input_path, expression=select_expression)
mock_select_key.assert_called_once_with(
key=input_path,
expression=select_expression,
input_serialization=input_serialization,
output_serialization=output_serialization,
)

conn = boto3.client("s3")
result = conn.get_object(Bucket=self.bucket, Key=self.output_key)
assert self.content == result["Body"].read()

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key", return_value="input")
@mock_aws
def test_execute_with_select_expression_and_serialization_config(self, mock_select_key):
input_path, output_path = self.s3_paths()
select_expression = "SELECT * FROM s3object s"
select_expr_serialization_config = {
"input_serialization": {"CSV": {}},
"output_serialization": {"CSV": {}},
}

op = S3FileTransformOperator(
source_s3_key=input_path,
dest_s3_key=output_path,
select_expression=select_expression,
select_expr_serialization_config=select_expr_serialization_config,
replace=True,
task_id="task_id",
)
op.execute(None)

input_serialization = select_expr_serialization_config.get("input_serialization")
output_serialization = select_expr_serialization_config.get("output_serialization")
mock_select_key.assert_called_once_with(
key=input_path,
expression=select_expression,
input_serialization=input_serialization,
output_serialization=output_serialization,
)

conn = boto3.client("s3")
result = conn.get_object(Bucket=self.bucket, Key=self.output_key)
Expand Down