diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 99b43df50a19f..cb7e06b9934a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -146,6 +146,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override protected def withNewChildInternal(newChild: Expression): GetStructField = copy(child = newChild) + + def metadata: Metadata = childSchema(ordinal).metadata } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index d5df6a12aa45b..47cdf21a8729f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -168,11 +168,8 @@ case class Alias(child: Expression, name: String)( override def metadata: Metadata = { explicitMetadata.getOrElse { child match { - case named: NamedExpression => - val builder = new MetadataBuilder().withMetadata(named.metadata) - nonInheritableMetadataKeys.foreach(builder.remove) - builder.build() - + case named: NamedExpression => removeNonInheritableMetadata(named.metadata) + case structField: GetStructField => removeNonInheritableMetadata(structField.metadata) case _ => Metadata.empty } } @@ -207,6 +204,12 @@ case class Alias(child: Expression, name: String)( "" } + private def removeNonInheritableMetadata(metadata: Metadata): Metadata = { + val builder = new MetadataBuilder().withMetadata(metadata) + nonInheritableMetadataKeys.foreach(builder.remove) + builder.build() + } + override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NamedExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NamedExpressionSuite.scala index f6cc19abaf9df..3e6f40f3b1ca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NamedExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NamedExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} class NamedExpressionSuite extends SparkFunSuite { @@ -51,4 +51,17 @@ class NamedExpressionSuite extends SparkFunSuite { val attr13 = UnresolvedAttribute("`a.b`") assert(attr13.sql === "`a.b`") } + + test("SPARK-34805: non inheritable metadata should be removed from child struct in Alias") { + val nonInheritableMetadataKey = "non-inheritable-key" + val metadata = new MetadataBuilder() + .putString(nonInheritableMetadataKey, "value1") + .putString("key", "value2") + .build() + val structType = StructType(Seq(StructField("value", StringType, metadata = metadata))) + val alias = Alias(GetStructField(AttributeReference("a", structType)(), 0), "my-alias")( + nonInheritableMetadataKeys = Seq(nonInheritableMetadataKey)) + assert(!alias.metadata.contains(nonInheritableMetadataKey)) + assert(alias.metadata.contains("key")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 995bf5d903ad4..b392b7536f5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -137,6 +137,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") } + test("SPARK-34805: as propagates metadata from nested column") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val df = spark.createDataFrame(sparkContext.emptyRDD[Row], + StructType(Seq( + StructField("parent", StructType(Seq( + StructField("child", StringType, metadata = metadata.build()) + )))) + )) + val newCol = df("parent.child") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("collect on column produced by a binary operator") { val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))