From e8d28b7021db02b14c635f51d3d7200c399880e5 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 26 Apr 2022 22:39:44 -0700 Subject: [PATCH 1/2] [SPARK-38918][SQL] Nested column pruning should filter out attributes that do not belong to the current relation This PR updates `ProjectionOverSchema` to use the outputs of the data source relation to filter the attributes in the nested schema pruning. This is needed because the attributes in the schema do not necessarily belong to the current data source relation. For example, if a filter contains a correlated subquery, then the subquery's children can contain attributes from both the inner query and the outer query. Since the `RewriteSubquery` batch happens after early scan pushdown rules, nested schema pruning can wrongly use the inner query's attributes to prune the outer query data schema, thus causing wrong results and unexpected exceptions. To fix a bug in `SchemaPruning`. No Unit test Closes #36216 from allisonwang-db/spark-38918-nested-column-pruning. Authored-by: allisonwang-db Signed-off-by: Liang-Chi Hsieh (cherry picked from commit 150434b5d7909dcf8248ffa5ec3d937ea3da09fd) Signed-off-by: Liang-Chi Hsieh (cherry picked from commit 793ba608181b3eba8f1f57fcdd12dcd3fe035362) Signed-off-by: allisonwang-db --- .../expressions/ProjectionOverSchema.scala | 8 +++- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/objects.scala | 2 +- .../execution/datasources/SchemaPruning.scala | 2 +- .../v2/V2ScanRelationPushDown.scala | 5 ++- .../datasources/SchemaPruningSuite.scala | 45 ++++++++++++++++++- 6 files changed, 56 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 03b5517f6df05..3e253e43ff9f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -24,15 +24,19 @@ import org.apache.spark.sql.types._ * field indexes and field counts of complex type extractors and attributes * are adjusted to fit the schema. All other expressions are left as-is. This * class is motivated by columnar nested schema pruning. + * + * @param schema nested column schema + * @param output output attributes of the data source relation. They are used to filter out + * attributes in the schema that do not belong to the current relation. */ -case class ProjectionOverSchema(schema: StructType) { +case class ProjectionOverSchema(schema: StructType, output: AttributeSet) { private val fieldNames = schema.fieldNames.toSet def unapply(expr: Expression): Option[Expression] = getProjection(expr) private def getProjection(expr: Expression): Option[Expression] = expr match { - case a: AttributeReference if fieldNames.contains(a.name) => + case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) => Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) case GetArrayItem(child, arrayItemOrdinal, failOnError) => getProjection(child).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e39fa23168bfd..fe6e9c52b04b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -58,6 +58,7 @@ abstract class Optimizer(catalogManager: CatalogManager) override protected val excludedOnceBatches: Set[String] = Set( "PartitionPruning", + "RewriteSubquery", "Extract Python UDFs") protected def fixedPoint = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 9643f5827b910..abc6e3d166814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -225,7 +225,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { } // Builds new projection. - val projectionOverSchema = ProjectionOverSchema(prunedSchema) + val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output)) val newProjects = p.projectList.map(_.transformDown { case projectionOverSchema(expr) => expr }).map { case expr: NamedExpression => expr } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index 55ae49e31a1a5..7b9bbd2cc7749 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -82,7 +82,7 @@ object SchemaPruning extends Rule[LogicalPlan] { // in dataSchema. if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { val prunedRelation = leafNodeBuilder(prunedDataSchema) - val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + val projectionOverSchema = ProjectionOverSchema(prunedDataSchema, AttributeSet(output)) Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, prunedRelation, projectionOverSchema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index daef7571806f7..7da9b63fc79c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -67,7 +67,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionOverSchema = + ProjectionOverSchema(output.toStructType, AttributeSet(output)) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 765d2fc584a7d..f2e24ff330edd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -57,11 +57,15 @@ abstract class SchemaPruningSuite contactId: Int, employer: Employer) + case class Employee(id: Int, name: FullName, employer: Company) + val janeDoe = FullName("Jane", "X.", "Doe") val johnDoe = FullName("John", "Y.", "Doe") val susanSmith = FullName("Susan", "Z.", "Smith") - val employer = Employer(0, Company("abc", "123 Business Street")) + val company = Company("abc", "123 Business Street") + + val employer = Employer(0, company) val employerWithNullCompany = Employer(1, null) val employerWithNullCompany2 = Employer(2, null) @@ -77,6 +81,8 @@ abstract class SchemaPruningSuite Department(1, "Marketing", 1, employerWithNullCompany) :: Department(2, "Operation", 4, employerWithNullCompany2) :: Nil + val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil + case class Name(first: String, last: String) case class BriefContact(id: Int, name: Name, address: String) @@ -580,6 +586,26 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") { + withContacts { + withEmployees { + val query = sql( + """ + |select count(*) + |from contacts c + |where not exists (select null from employees e where e.name.first = c.name.first + | and e.employer.name = c.employer.company.name) + |""".stripMargin) + checkScan(query, + "struct," + + "employer:struct>>", + "struct," + + "employer:struct>") + checkAnswer(query, Row(3)) + } + } + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { @@ -660,6 +686,23 @@ abstract class SchemaPruningSuite } } + private def withEmployees(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeDataSourceFile(employees, new File(path + "/employees")) + + // Providing user specified schema. Inferred schema from different data sources might + // be different. + val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " + + "`employer` STRUCT<`name`: STRING, `address`: STRING>" + spark.read.format(dataSourceName).schema(schema).load(path + "/employees") + .createOrReplaceTempView("employees") + + testThunk + } + } + case class MixedCaseColumn(a: String, B: Int) case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn) From 37d8e9ee99732c5bc53fcce1395e74623757effe Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 27 Apr 2022 21:46:01 -0700 Subject: [PATCH 2/2] fix test --- .../spark/sql/execution/datasources/SchemaPruningSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index f2e24ff330edd..b55dd20913188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -597,8 +597,7 @@ abstract class SchemaPruningSuite | and e.employer.name = c.employer.company.name) |""".stripMargin) checkScan(query, - "struct," + - "employer:struct>>", + "struct,employer:struct>>", "struct," + "employer:struct>") checkAnswer(query, Row(3))