diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index a3f9c9245e041..f2733495efc05 100644 --- a/airflow/providers/amazon/aws/operators/s3.py +++ b/airflow/providers/amazon/aws/operators/s3.py @@ -610,6 +610,7 @@ class S3FileTransformOperator(BaseOperator): :param dest_s3_key: The key to be written from S3. (templated) :param transform_script: location of the executable transformation script :param select_expression: S3 Select expression + :param select_expr_serialization_config: A dictionary that contains input and output serialization configurations for S3 Select. :param script_args: arguments for transformation script (templated) :param source_aws_conn_id: source s3 connection :param source_verify: Whether or not to verify SSL certificates for S3 connection. @@ -641,6 +642,7 @@ def __init__( dest_s3_key: str, transform_script: str | None = None, select_expression=None, + select_expr_serialization_config: dict[str, dict[str, dict]] | None = None, script_args: Sequence[str] | None = None, source_aws_conn_id: str | None = "aws_default", source_verify: bool | str | None = None, @@ -659,6 +661,7 @@ def __init__( self.replace = replace self.transform_script = transform_script self.select_expression = select_expression + self.select_expr_serialization_config = select_expr_serialization_config or {} self.script_args = script_args or [] self.output_encoding = sys.getdefaultencoding() @@ -678,7 +681,14 @@ def execute(self, context: Context): self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name) if self.select_expression is not None: - content = source_s3.select_key(key=self.source_s3_key, expression=self.select_expression) + input_serialization = self.select_expr_serialization_config.get("input_serialization") + output_serialization = self.select_expr_serialization_config.get("output_serialization") + content = source_s3.select_key( + key=self.source_s3_key, + 
+ expression=self.select_expression, + input_serialization=input_serialization, + output_serialization=output_serialization, + ) f_source.write(content.encode("utf-8")) else: source_s3_key_object.download_fileobj(Fileobj=f_source) diff --git a/tests/providers/amazon/aws/operators/test_s3.py b/tests/providers/amazon/aws/operators/test_s3.py index e2cf5d3543eaa..5e4bbffbd07d9 100644 --- a/tests/providers/amazon/aws/operators/test_s3.py +++ b/tests/providers/amazon/aws/operators/test_s3.py @@ -284,6 +284,8 @@ def test_execute_with_transform_script_args(self, mock_popen): def test_execute_with_select_expression(self, mock_select_key): input_path, output_path = self.s3_paths() select_expression = "SELECT * FROM s3object s" + input_serialization = None + output_serialization = None op = S3FileTransformOperator( source_s3_key=input_path, @@ -294,7 +296,45 @@ def test_execute_with_select_expression(self, mock_select_key): ) op.execute(None) - mock_select_key.assert_called_once_with(key=input_path, expression=select_expression) + mock_select_key.assert_called_once_with( + key=input_path, + expression=select_expression, + input_serialization=input_serialization, + output_serialization=output_serialization, + ) + + conn = boto3.client("s3") + result = conn.get_object(Bucket=self.bucket, Key=self.output_key) + assert self.content == result["Body"].read() + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key", return_value="input") + @mock_aws + def test_execute_with_select_expression_and_serialization_config(self, mock_select_key): + input_path, output_path = self.s3_paths() + select_expression = "SELECT * FROM s3object s" + select_expr_serialization_config = { + "input_serialization": {"CSV": {}}, + "output_serialization": {"CSV": {}}, + } + + op = S3FileTransformOperator( + source_s3_key=input_path, + dest_s3_key=output_path, + select_expression=select_expression, + select_expr_serialization_config=select_expr_serialization_config, + replace=True, 
+ task_id="task_id", + ) + op.execute(None) + + input_serialization = select_expr_serialization_config.get("input_serialization") + output_serialization = select_expr_serialization_config.get("output_serialization") + mock_select_key.assert_called_once_with( + key=input_path, + expression=select_expression, + input_serialization=input_serialization, + output_serialization=output_serialization, + ) conn = boto3.client("s3") result = conn.get_object(Bucket=self.bucket, Key=self.output_key)