diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index ce655afeb4a9a..aad41233b2661 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -96,6 +96,8 @@ license: | - In Spark 3.2, `FloatType` is mapped to `FLOAT` in MySQL. Prior to this, it used to be mapped to `REAL`, which is by default a synonym to `DOUBLE PRECISION` in MySQL. - In Spark 3.2, the query executions triggered by `DataFrameWriter` are always named `command` when being sent to `QueryExecutionListener`. In Spark 3.1 and earlier, the name is one of `save`, `insertInto`, `saveAsTable`, `create`, `append`, `overwrite`, `overwritePartitions`, `replace`. + + - In Spark 3.2, `Dataset.unionByName` with `allowMissingColumns` set to true will add missing nested fields to the end of structs. In Spark 3.1, nested struct fields are sorted alphabetically. ## Upgrading from Spark SQL 3.0 to 3.1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 08cc61f819004..8cc3547e8316e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -20,137 +20,66 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields} +import org.apache.spark.sql.catalyst.optimizer.{CombineUnions} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.UNION import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils -import org.apache.spark.unsafe.types.UTF8String /** 
* Resolves different children of Union to a common set of columns. */ object ResolveUnion extends Rule[LogicalPlan] { /** - * This method sorts columns recursively in a struct expression based on column names. + * Adds missing fields recursively into given `col` expression, based on the expected struct + * fields from merging the two schemas. This is called by `compareAndAddFields` when we find two + * struct columns with same name but different nested fields. This method will recursively + * return a new struct with all of the expected fields, adding null values when `col` doesn't + * already contain them. Currently we don't support merging structs nested inside of arrays + * or maps. */ - private def sortStructFields(expr: Expression): Expression = { - val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { - case (name, i) => - val fieldExpr = GetStructField(KnownNotNull(expr), i) - if (fieldExpr.dataType.isInstanceOf[StructType]) { - (name, sortStructFields(fieldExpr)) - } else { - (name, fieldExpr) - } - }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2)) - - val newExpr = CreateNamedStruct(existingExprs) - if (expr.nullable) { - If(IsNull(expr), Literal(null, newExpr.dataType), newExpr) - } else { - newExpr - } - } - - /** - * Assumes input expressions are field expression of `CreateNamedStruct`. This method - * sorts the expressions based on field names. - */ - private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = { - fieldExprs.grouped(2).map { e => - Seq(e.head, e.last) - }.toSeq.sortBy { pair => - assert(pair.head.isInstanceOf[Literal]) - pair.head.eval().asInstanceOf[UTF8String].toString - }.flatten - } - - /** - * This helper method sorts fields in a `UpdateFields` expression by field name. 
- */ - private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp { - case u: UpdateFields if u.resolved => - u.evalExpr match { - case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) => - val sorted = sortFieldExprs(fieldExprs) - val newStruct = CreateNamedStruct(sorted) - i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct) - case CreateNamedStruct(fieldExprs) => - val sorted = sortFieldExprs(fieldExprs) - val newStruct = CreateNamedStruct(sorted) - newStruct - case other => - throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " + - "Please file a bug report with this error message, stack trace, and the query.") - } - } - - /** - * Adds missing fields recursively into given `col` expression, based on the target `StructType`. - * This is called by `compareAndAddFields` when we find two struct columns with same name but - * different nested fields. This method will find out the missing nested fields from `col` to - * `target` struct and add these missing nested fields. Currently we don't support finding out - * missing nested fields of struct nested in array or struct nested in map. - */ - private def addFields(col: NamedExpression, target: StructType): Expression = { + private def addFields(col: Expression, targetType: StructType): Expression = { assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") val resolver = conf.resolver - val missingFieldsOpt = - StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) - - // We need to sort columns in result, because we might add another column in other side. - // E.g., we want to union two structs "a int, b long" and "a int, c string". - // If we don't sort, we will have "a int, b long, c string" and - // "a int, c string, b long", which are not compatible. 
- if (missingFieldsOpt.isEmpty) { - sortStructFields(col) - } else { - missingFieldsOpt.map { s => - val struct = addFieldsInto(col, s.fields) - // Combines `WithFields`s to reduce expression tree. - val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields) - val sorted = sortStructFieldsInWithFields(reducedStruct) - sorted - }.get + val colType = col.dataType.asInstanceOf[StructType] + + val newStructFields = mutable.ArrayBuffer.empty[Expression] + + val targetStructFields = targetType.fields.foreach { expectedField => + val currentField = colType.fields.find(f => resolver(f.name, expectedField.name)) + + val newExpression = (currentField, expectedField.dataType) match { + case (Some(cf), expectedType: StructType) if cf.dataType.isInstanceOf[StructType] => + val extractedValue = ExtractValue(col, Literal(cf.name), resolver) + val combinedStruct = addFields(extractedValue, expectedType) + if (extractedValue.nullable) { + If(IsNull(extractedValue), + Literal(null, combinedStruct.dataType), + combinedStruct) + } else { + combinedStruct + } + case (Some(cf), _) => + ExtractValue(col, Literal(cf.name), resolver) + case (None, expectedType) => + Literal(null, expectedType) + } + newStructFields ++= Literal(expectedField.name) :: newExpression :: Nil } - } - /** - * Adds missing fields recursively into given `col` expression. The missing fields are given - * in `fields`. For example, given `col` as "z struct, x int", and `fields` is - * "z struct, w string". This method will add a nested `z.w` field and a top-level - * `w` field to `col` and fill null values for them. Note that because we might also add missing - * fields at other side of Union, we must make sure corresponding attributes at two sides have - * same field order in structs, so when we adding missing fields, we will sort the fields based on - * field names. So the data type of returned expression will be - * "w string, x int, z struct". 
- */ - private def addFieldsInto( - col: Expression, - fields: Seq[StructField]): Expression = { - fields.foldLeft(col) { case (currCol, field) => - field.dataType match { - case st: StructType => - val resolver = conf.resolver - val colField = currCol.dataType.asInstanceOf[StructType] - .find(f => resolver(f.name, field.name)) - if (colField.isEmpty) { - // The whole struct is missing. Add a null. - UpdateFields(currCol, field.name, Literal(null, st)) - } else { - UpdateFields(currCol, field.name, - addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields)) - } - case dt => - UpdateFields(currCol, field.name, Literal(null, dt)) + colType.fields + .filter(f => targetType.fields.find(tf => resolver(f.name, tf.name)).isEmpty) + .foreach { f => + newStructFields ++= Literal(f.name) :: ExtractValue(col, Literal(f.name), resolver) :: Nil } - } + + CreateNamedStruct(newStructFields.toSeq) } + /** * This method will compare right to left plan's outputs. If there is one struct attribute * at right side has same name with left side struct attribute, but two structs are not the @@ -208,13 +137,11 @@ object ResolveUnion extends Rule[LogicalPlan] { left: LogicalPlan, right: LogicalPlan, allowMissingCol: Boolean): LogicalPlan = { - val rightOutputAttrs = right.output - // Builds a project list for `right` based on `left` output names val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) // Delegates failure checks to `CheckAnalysis` - val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) + val notFoundAttrs = right.output.diff(rightProjectList ++ aliased) val rightChild = Project(rightProjectList ++ notFoundAttrs, right) // Builds a project for `logicalPlan` based on `right` output names, if allowing @@ -230,6 +157,7 @@ object ResolveUnion extends Rule[LogicalPlan] { } else { left } + Union(leftChild, rightChild) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index 9aa2766dd3e8c..76be8328bea4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -38,7 +38,7 @@ object SchemaPruning extends SQLConfHelper { // original schema val mergedSchema = requestedRootFields .map { root: RootField => StructType(Array(root.field)) } - .reduceLeft(_ merge _) + .reduceLeft((left, right) => left.merge(right, resolver)) val mergedDataSchema = StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d))) // Sort the fields of mergedDataSchema according to their order in dataSchema, @@ -113,7 +113,7 @@ object SchemaPruning extends SQLConfHelper { // this optional root field too. val rootFieldType = StructType(Array(root.field)) val optFieldType = StructType(Array(opt.field)) - val merged = optFieldType.merge(rootFieldType) + val merged = optFieldType.merge(rootFieldType, conf.resolver) merged.sameType(optFieldType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9438a444d6db3..2397f5919e7da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -307,7 +307,7 @@ case class Union( children.map(_.output).transpose.map { attrs => val firstAttr = attrs.head val nullable = attrs.exists(_.nullable) - val newDt = attrs.map(_.dataType).reduce(StructType.merge) + val newDt = attrs.map(_.dataType).reduce(StructType.merge(conf.resolver)) if 
(firstAttr.dataType == newDt) { firstAttr.withNullability(nullable) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 8ff0536c2f3a0..f4d7a38c18d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -483,8 +483,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be * thrown. */ - private[sql] def merge(that: StructType): StructType = - StructType.merge(this, that).asInstanceOf[StructType] + private[sql] def merge(that: StructType, resolver: Resolver): StructType = + StructType.merge(resolver)(this, that).asInstanceOf[StructType] override private[spark] def asNullable: StructType = { val newFields = fields.map { @@ -555,32 +555,31 @@ object StructType extends AbstractDataType { case _ => dt } - private[sql] def merge(left: DataType, right: DataType): DataType = + private[sql] def merge(resolver: Resolver)(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, leftContainsNull), ArrayType(rightElementType, rightContainsNull)) => ArrayType( - merge(leftElementType, rightElementType), + merge(resolver)(leftElementType, rightElementType), leftContainsNull || rightContainsNull) case (MapType(leftKeyType, leftValueType, leftContainsNull), MapType(rightKeyType, rightValueType, rightContainsNull)) => MapType( - merge(leftKeyType, rightKeyType), - merge(leftValueType, rightValueType), + merge(resolver)(leftKeyType, rightKeyType), + merge(resolver)(leftValueType, rightValueType), leftContainsNull || rightContainsNull) case (StructType(leftFields), StructType(rightFields)) => val newFields = mutable.ArrayBuffer.empty[StructField] - val rightMapped = fieldsMap(rightFields) 
leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightMapped.get(leftName) + rightFields.find(f => resolver(leftName, f.name)) .map { case rightField @ StructField(rightName, rightType, rightNullable, _) => try { leftField.copy( - dataType = merge(leftType, rightType), + dataType = merge(resolver)(leftType, rightType), nullable = leftNullable || rightNullable) } catch { case NonFatal(e) => @@ -593,12 +592,9 @@ object StructType extends AbstractDataType { .foreach(newFields += _) } - val leftMapped = fieldsMap(leftFields) rightFields - .filterNot(f => leftMapped.get(f.name).nonEmpty) - .foreach { f => - newFields += f - } + .filter(f => leftFields.find(lf => resolver(f.name, lf.name)).isEmpty) + .foreach(newFields += _) StructType(newFields.toSeq) @@ -634,39 +630,4 @@ object StructType extends AbstractDataType { fields.foreach(s => map.put(s.name, s)) map } - - /** - * Returns a `StructType` that contains missing fields recursively from `source` to `target`. - * Note that this doesn't support looking into array type and map type recursively. - */ - def findMissingFields( - source: StructType, - target: StructType, - resolver: Resolver): Option[StructType] = { - def bothStructType(dt1: DataType, dt2: DataType): Boolean = - dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType] - - val newFields = mutable.ArrayBuffer.empty[StructField] - - target.fields.foreach { field => - val found = source.fields.find(f => resolver(field.name, f.name)) - if (found.isEmpty) { - // Found a missing field in `source`. - newFields += field - } else if (bothStructType(found.get.dataType, field.dataType) && - !found.get.dataType.sameType(field.dataType)) { - // Found a field with same name, but different data type. 
- findMissingFields(found.get.dataType.asInstanceOf[StructType], - field.dataType.asInstanceOf[StructType], resolver).map { missingType => - newFields += found.get.copy(dataType = missingType) - } - } - } - - if (newFields.isEmpty) { - None - } else { - Some(StructType(newFields.toSeq)) - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 3c85eef76612a..ba44cd3c75d90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -22,6 +22,7 @@ import com.fasterxml.jackson.core.JsonParseException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes class DataTypeSuite extends SparkFunSuite { @@ -153,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { StructField("b", LongType) :: Nil) val message = intercept[SparkException] { - left.merge(right) + left.merge(right, SQLConf.get.resolver) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. 
" + "Failed to merge incompatible data types float and bigint")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 820f32614e04a..29e5c9974dc7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -150,95 +150,36 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { assert(fromDDL(interval).toDDL === interval) } - test("find missing (nested) fields") { - val schema = StructType.fromDDL("c1 INT, c2 STRUCT>") + test("SPARK-35290: Struct merging case insensitive") { + val schema1 = StructType.fromDDL("a1 INT, a2 STRING, nested STRUCT") + val schema2 = StructType.fromDDL("A2 STRING, a3 DOUBLE, nested STRUCT") val resolver = SQLConf.get.resolver - val source1 = StructType.fromDDL("c1 INT") - val missing1 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source1, schema, resolver) - .exists(_.sameType(missing1))) + assert(schema1.merge(schema2, resolver) === StructType.fromDDL( + "a1 INT, a2 STRING, nested STRUCT, a3 DOUBLE" + )) - val source2 = StructType.fromDDL("c1 INT, c3 STRING") - val missing2 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source2, schema, resolver) - .exists(_.sameType(missing2))) - - val source3 = StructType.fromDDL("c1 INT, c2 STRUCT") - val missing3 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source3, schema, resolver) - .exists(_.sameType(missing3))) - - val source4 = StructType.fromDDL("c1 INT, c2 STRUCT>") - val missing4 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source4, schema, resolver) - .exists(_.sameType(missing4))) + assert(schema2.merge(schema1, resolver) === StructType.fromDDL( + "A2 STRING, a3 DOUBLE, nested STRUCT, a1 INT" + )) } - test("find missing (nested) fields: 
array and map") { - val resolver = SQLConf.get.resolver - - val schemaWithArray = StructType.fromDDL("c1 INT, c2 ARRAY>") - val source5 = StructType.fromDDL("c1 INT") - val missing5 = StructType.fromDDL("c2 ARRAY>") - assert( - StructType.findMissingFields(source5, schemaWithArray, resolver) - .exists(_.sameType(missing5))) - - val schemaWithMap1 = StructType.fromDDL( - "c1 INT, c2 MAP, STRING>, c3 LONG") - val source6 = StructType.fromDDL("c1 INT, c3 LONG") - val missing6 = StructType.fromDDL("c2 MAP, STRING>") - assert( - StructType.findMissingFields(source6, schemaWithMap1, resolver) - .exists(_.sameType(missing6))) - - val schemaWithMap2 = StructType.fromDDL( - "c1 INT, c2 MAP>, c3 STRING") - val source7 = StructType.fromDDL("c1 INT, c3 STRING") - val missing7 = StructType.fromDDL("c2 MAP>") - assert( - StructType.findMissingFields(source7, schemaWithMap2, resolver) - .exists(_.sameType(missing7))) - - // Unsupported: nested struct in array, map - val source8 = StructType.fromDDL("c1 INT, c2 ARRAY>") - // `findMissingFields` doesn't support looking into nested struct in array type. - assert(StructType.findMissingFields(source8, schemaWithArray, resolver).isEmpty) - - val source9 = StructType.fromDDL("c1 INT, c2 MAP, STRING>, c3 LONG") - // `findMissingFields` doesn't support looking into nested struct in map type. - assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).isEmpty) - - val source10 = StructType.fromDDL("c1 INT, c2 MAP>, c3 STRING") - // `findMissingFields` doesn't support looking into nested struct in map type. 
- assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).isEmpty) - } - - test("find missing (nested) fields: case sensitive cases") { + test("SPARK-35290: Struct merging case sensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val schema = StructType.fromDDL("c1 INT, c2 STRUCT>") + val schema1 = StructType.fromDDL("a1 INT, a2 STRING, nested STRUCT") + val schema2 = StructType.fromDDL( + "A2 STRING, a3 DOUBLE, nested STRUCT") val resolver = SQLConf.get.resolver - val source1 = StructType.fromDDL("c1 INT, C2 LONG") - val missing1 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source1, schema, resolver) - .exists(_.sameType(missing1))) - - val source2 = StructType.fromDDL("c2 LONG") - val missing2 = StructType.fromDDL("c1 INT") - assert(StructType.findMissingFields(source2, schema, resolver) - .exists(_.sameType(missing2))) - - val source3 = StructType.fromDDL("c1 INT, c2 STRUCT>") - val missing3 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source3, schema, resolver) - .exists(_.sameType(missing3))) - - val source4 = StructType.fromDDL("c1 INT, c2 STRUCT>") - val missing4 = StructType.fromDDL("c2 STRUCT>") - assert(StructType.findMissingFields(source4, schema, resolver) - .exists(_.sameType(missing4))) + assert(schema1.merge(schema2, resolver) === StructType.fromDDL( + "a1 INT, a2 STRING, nested STRUCT, " + + "A2 STRING, a3 DOUBLE" + )) + + assert(schema2.merge(schema1, resolver) === StructType.fromDDL( + "A2 STRING, a3 DOUBLE, nested STRUCT, " + + "a1 INT, a2 STRING" + )) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 08e59b1b00f47..b147a7a09f435 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2080,10 +2080,8 @@ class Dataset[T] private[sql]( * }}} * * Note that `allowMissingColumns` supports 
nested column in struct types. Missing nested columns - * of struct columns with same name will also be filled with null values. This currently does not - * support nested columns in array and map types. Note that if there is any missing nested columns - * to be filled, in order to make consistent schema between two sides of union, the nested fields - * of structs will be sorted after merging schema. + * of struct columns with the same name will also be filled with null values and added to the end + * of the struct. This currently does not support nested columns in array and map types. * * @group typedrel * @since 3.1.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index b537040fe71df..022a41e5cd130 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -684,7 +684,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { children.map(_.output).transpose.map { attrs => val firstAttr = attrs.head val nullable = attrs.exists(_.nullable) - val newDt = attrs.map(_.dataType).reduce(StructType.merge) + val newDt = attrs.map(_.dataType).reduce(StructType.merge(conf.resolver)) if (firstAttr.dataType == newDt) { firstAttr.withNullability(nullable) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala index 28097c35401c9..9299632c482bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala @@ -60,6 +60,7 @@ object SchemaMergeUtils extends Logging { sparkSession.sparkContext.defaultParallelism) val 
ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + val resolver = sparkSession.sessionState.conf.resolver // Issues a Spark job to read Parquet/ORC schema in parallel. val partiallyMergedSchemas = @@ -80,7 +81,7 @@ object SchemaMergeUtils extends Logging { var mergedSchema = schemas.head schemas.tail.foreach { schema => try { - mergedSchema = mergedSchema.merge(schema) + mergedSchema = mergedSchema.merge(schema, resolver) } catch { case cause: SparkException => throw new SparkException( s"Failed merging schema:\n${schema.treeString}", cause) @@ -96,7 +97,7 @@ object SchemaMergeUtils extends Logging { var finalSchema = partiallyMergedSchemas.head partiallyMergedSchemas.tail.foreach { schema => try { - finalSchema = finalSchema.merge(schema) + finalSchema = finalSchema.merge(schema, resolver) } catch { case cause: SparkException => throw new SparkException( s"Failed merging schema:\n${schema.treeString}", cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 48e2e6e57d838..40b4fe106e02c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -385,6 +385,8 @@ object ParquetFileFormat extends Logging { private[parquet] def readSchema( footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { + val resolver = sparkSession.sessionState.conf.resolver + val converter = new ParquetToSparkSchemaConverter( sparkSession.sessionState.conf.isParquetBinaryAsString, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) @@ -429,7 +431,7 @@ object ParquetFileFormat extends Logging { } finalSchemas.reduceOption { (left, right) => - try left.merge(right) catch { case e: Throwable => + try left.merge(right, 
resolver) catch { case e: Throwable => throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(left, right, e) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 797673ae15ba8..e622528afc69d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -677,27 +677,30 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a") - val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>" - var unionDf = df1.unionByName(df2, true) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `c`: STRING, `b`: BIGINT>>") checkAnswer(unionDf, - Row(0, Row(0, 1, Row(1, null, "2"))) :: - Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) + Row(0, Row(0, 1, Row(1, "2", null))) :: + Row(1, Row(1, 2, Row(2, null, 3L))) :: Nil) unionDf = df2.unionByName(df1, true) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") checkAnswer(unionDf, Row(1, Row(1, 2, Row(2, 3L, null))) :: Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a") unionDf = df1.unionByName(df3, true) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `c`: STRING, `b`: BIGINT>>") checkAnswer(unionDf, - Row(0, Row(0, 1, Row(1, null, "2"))) :: + Row(0, Row(0, 1, Row(1, "2", null))) 
:: Row(2, Row(2, 3, null)) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" + @@ -707,29 +710,29 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a") var unionDf = df1.unionByName(df2, true) - checkAnswer(unionDf, - Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: - Row(1, Row(1, 2, Row(2, null, 3L, null))) :: Nil) assert(unionDf.schema.toDDL == "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>") + "`nested`: STRUCT<`a`: INT, `c`: STRING, `A`: INT, `b`: BIGINT>>") + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(1, "2", null, null))) :: + Row(1, Row(1, 2, Row(null, null, 2, 3L))) :: Nil) unionDf = df2.unionByName(df1, true) - checkAnswer(unionDf, - Row(1, Row(1, 2, Row(2, null, 3L, null))) :: - Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: Nil) assert(unionDf.schema.toDDL == "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>") + "`nested`: STRUCT<`A`: INT, `b`: BIGINT, `a`: INT, `c`: STRING>>") + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, 3L, null, null))) :: + Row(0, Row(0, 1, Row(null, null, 1, "2"))) :: Nil) val df3 = Seq((2, UnionClass1b(2, 3L, UnionClass3(4, 5L)))).toDF("id", "a") unionDf = df2.unionByName(df3, true) - checkAnswer(unionDf, - Row(1, Row(1, 2, Row(2, null, 3L))) :: - Row(2, Row(2, 3, Row(null, 4, 5L))) :: Nil) assert(unionDf.schema.toDDL == "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT>>") + "`nested`: STRUCT<`A`: INT, `b`: BIGINT, `a`: INT>>") + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, 3L, null))) :: + Row(2, Row(2, 3, Row(null, 5L, 4))) :: Nil) } } @@ -743,17 +746,59 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructField("a", 
StringType))) val nestedStructValues2 = Row("b", "a") - val df1: DataFrame = spark.createDataFrame( + val df1 = spark.createDataFrame( sparkContext.parallelize(Row(nestedStructValues1) :: Nil), StructType(Seq(StructField("topLevelCol", nestedStructType1)))) - val df2: DataFrame = spark.createDataFrame( + val df2 = spark.createDataFrame( sparkContext.parallelize(Row(nestedStructValues2) :: Nil), StructType(Seq(StructField("topLevelCol", nestedStructType2)))) val union = df1.unionByName(df2, allowMissingColumns = true) - checkAnswer(union, Row(Row(null, "b")) :: Row(Row("a", "b")) :: Nil) - assert(union.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRING, `b`: STRING>") + assert(union.schema.toDDL == "`topLevelCol` STRUCT<`b`: STRING, `a`: STRING>") + checkAnswer(union, Row(Row("b", null)) :: Row(Row("b", "a")) :: Nil) + } + + test("SPARK-35290: Make unionByName null-filling behavior work with struct columns" + + " - sorting edge case") { + val nestedStructType1 = StructType(Seq( + StructField("b", StructType(Seq( + StructField("ba", StringType) + ))) + )) + val nestedStructValues1 = Row(Row("ba")) + + val nestedStructType2 = StructType(Seq( + StructField("a", StructType(Seq( + StructField("aa", StringType) + ))), + StructField("b", StructType(Seq( + StructField("bb", StringType) + ))) + )) + val nestedStructValues2 = Row(Row("aa"), Row("bb")) + + val df1 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues1) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType1)))) + + val df2 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues2) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType2)))) + + var unionDf = df1.unionByName(df2, true) + assert(unionDf.schema.toDDL == "`topLevelCol` " + + "STRUCT<`b`: STRUCT<`ba`: STRING, `bb`: STRING>, `a`: STRUCT<`aa`: STRING>>") + checkAnswer(unionDf, + Row(Row(Row("ba", null), null)) :: + Row(Row(Row(null, "bb"), Row("aa"))) :: Nil) + + unionDf = 
df2.unionByName(df1, true) + assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRUCT<`aa`: STRING>, " + + "`b`: STRUCT<`bb`: STRING, `ba`: STRING>>") + checkAnswer(unionDf, + Row(Row(null, Row(null, "ba"))) :: + Row(Row(Row("aa"), Row("bb", null))) :: Nil) } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr") { @@ -777,7 +822,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { depthCounter -= 1 } - val df: DataFrame = spark.createDataFrame( + val df = spark.createDataFrame( sparkContext.parallelize(Row(struct) :: Nil), StructType(Seq(StructField("nested0Col0", structType)))) @@ -800,16 +845,16 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)) val row2 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row( - Row(0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9)) + Row(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 
13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) // scalastyle:on checkAnswer(union, row1 :: row2 :: Nil) }