From 97b7ee8dc2ee49bd1d49e0a1a45781affd1e0968 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 12 Mar 2022 15:23:39 +0800 Subject: [PATCH 01/13] [SPARK-38533][SQL] DS V2 aggregate push-down supports project with alias --- .../util/V2ExpressionSQLBuilder.java | 4 + .../catalyst/util/V2ExpressionBuilder.scala | 5 +- .../v2/V2ScanRelationPushDown.scala | 47 ++++++++-- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 86 +++++++++++++++++-- 4 files changed, 130 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 0af0d88b0f622..d69ebb1d8b9b9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -72,6 +72,8 @@ public String build(Expression expr) { return visitNot(build(e.children()[0])); case "~": return visitUnaryArithmetic(name, build(e.children()[0])); + case "AS": + return visitAs(build(e.children()[0]), build(e.children()[1])); case "CASE_WHEN": List children = new ArrayList<>(); for (Expression child : e.children()) { @@ -125,6 +127,8 @@ protected String visitNot(String v) { protected String visitUnaryArithmetic(String name, String v) { return name +" (" + v + ")"; } + protected String visitAs(String v, String name) { return v +" AS " + name; } + protected String visitCaseWhen(String[] children) { StringBuilder sb = new StringBuilder("CASE"); for (int i = 0; i < children.length; i += 2) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 1e361695056a7..eece110be915b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.types.StringType /** * The builder to generate V2 expressions from catalyst expressions. @@ -68,6 +69,8 @@ class V2ExpressionBuilder(e: Expression) { .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) case BitwiseNot(child) => generateExpression(child) .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case Alias(col, name) => generateExpression(col).map(c => + new GeneralScalarExpression("AS", Array[V2Expression](c, LiteralValue(name, StringType)))) case CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression) val values = branches.map(_._2).flatMap(generateExpression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index b4bd02773edfb..1b39f5fbec53a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -93,22 +93,38 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => + if filters.isEmpty && + project.forall(p => p.isInstanceOf[AttributeReference] || p.isInstanceOf[Alias]) => sHolder.builder match { case r: SupportsPushDownAggregates => + val aliasAttrToOriginAttr = mutable.HashMap.empty[Expression, AttributeReference] + val originAttrToAliasAttr = mutable.HashMap.empty[Expression, Attribute] + collectAliases(project, aliasAttrToOriginAttr, originAttrToAliasAttr) + val newResultExpressions = resultExpressions.map { expr => + expr.transform { + case r: AttributeReference if aliasAttrToOriginAttr.contains(r.canonicalized) => + aliasAttrToOriginAttr(r.canonicalized) + } + }.asInstanceOf[Seq[NamedExpression]] + val newGroupingExpressions = groupingExpressions.map { expr => + expr.transform { + case r: AttributeReference if aliasAttrToOriginAttr.contains(r.canonicalized) => + aliasAttrToOriginAttr(r.canonicalized) + } + } val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal) + val aggregates = collectAggregates(newResultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - groupingExpressions, sHolder.relation.output) + newGroupingExpressions, sHolder.relation.output) val translatedAggregates = DataSourceStrategy.translateAggregation( normalizedAggregates, normalizedGroupingExpressions) val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { if (translatedAggregates.isEmpty || r.supportCompletePushDown(translatedAggregates.get) || translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { - (resultExpressions, aggregates, translatedAggregates) + (newResultExpressions, aggregates, translatedAggregates) } else { // scalastyle:off // The data source doesn't support the complete push-down of this aggregation. @@ -203,8 +219,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) + val resultExpressionWithAliases = finalResultExpressions.map { + case attr: AttributeReference + if originAttrToAliasAttr.contains(attr.canonicalized) => + val alias = originAttrToAliasAttr(attr.canonicalized) + Alias(attr, alias.name)(alias.exprId) + case other => other + } if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = resultExpressions.map { expr => + val projectExpressions = resultExpressionWithAliases.map { expr => // TODO At present, only push down group by attribute is supported. // In future, more attribute conversion is extended here. e.g. GetStructField expr.transform { @@ -218,7 +241,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { Project(projectExpressions, scanRelation) } else { val plan = Aggregate(output.take(groupingExpressions.length), - finalResultExpressions, scanRelation) + resultExpressionWithAliases, scanRelation) // scalastyle:off // Change the optimized logical plan to reflect the pushed down aggregate @@ -282,6 +305,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def collectAliases(project: Seq[NamedExpression], + aliasAttrToOriginAttr: mutable.HashMap[Expression, AttributeReference], + originAttrToAliasAttr: mutable.HashMap[Expression, Attribute]) = { + project.collect { + case alias @ Alias(attr: AttributeReference, _) => + val output = alias.toAttribute + aliasAttrToOriginAttr(output.canonicalized) = attr + originAttrToAliasAttr(attr.canonicalized) = output + case other => other + } + } + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 17bd7f7a6d5bc..b8bfc56538469 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -791,15 +791,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } - test("scan with aggregate push-down: aggregate over alias NOT push down") { + test("scan with aggregate push-down: aggregate over alias push down") { val cols = Seq("a", "b", "c", "d") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") - checkAggregateRemoved(df2, false) + checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.aggregation.isEmpty) + case relation: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.nonEmpty) } } checkAnswer(df2, Seq(Row(53000.00))) @@ -1044,4 +1048,76 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin) checkAnswer(df, Seq.empty[Row]) } + + test("scan with aggregate push-down: complete push-down aggregate with alias") { + val df = spark.table("h2.test.employee") + .select($"DEPT", $"SALARY".as("mySalary")) + .groupBy($"DEPT") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + + val df2 = spark.table("h2.test.employee") + .select($"DEPT".as("myDept"), $"SALARY".as("mySalary")) + .groupBy($"myDept") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + } + + test("scan with aggregate push-down: partial push-down aggregate with alias") { + val df = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME", $"SALARY".as("mySalary")) + .groupBy($"NAME") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + + val df2 = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME".as("myName"), $"SALARY".as("mySalary")) + .groupBy($"myName") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2, false) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + } } From 611748132a5166c1d6c1c92b81a2ebc9bbac15f1 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 12 Mar 2022 19:27:27 +0800 Subject: [PATCH 02/13] Update code --- .../v2/V2ScanRelationPushDown.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 1b39f5fbec53a..599d877e3abbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -120,7 +120,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { newGroupingExpressions, sHolder.relation.output) val translatedAggregates = DataSourceStrategy.translateAggregation( normalizedAggregates, normalizedGroupingExpressions) - val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { + val (selectedResultExpressions, selectedAggregates, selectedTranslatedAggregates) = { if (translatedAggregates.isEmpty || r.supportCompletePushDown(translatedAggregates.get) || translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { @@ -172,13 +172,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } - if (finalTranslatedAggregates.isEmpty) { + if (selectedTranslatedAggregates.isEmpty) { aggNode // return original plan node - } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && - !supportPartialAggPushDown(finalTranslatedAggregates.get)) { + } else if (!r.supportCompletePushDown(selectedTranslatedAggregates.get) && + !supportPartialAggPushDown(selectedTranslatedAggregates.get)) { aggNode // return original plan node } else { - val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) + val pushedAggregates = selectedTranslatedAggregates.filter(r.pushAggregation) if (pushedAggregates.isEmpty) { aggNode // return original plan node } else { @@ -198,7 +198,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] // scalastyle:on val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + finalAggregates.length) + assert(newOutput.length == groupingExpressions.length + selectedAggregates.length) val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b @@ -219,7 +219,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - val resultExpressionWithAliases = finalResultExpressions.map { + val finalResultExpressions = selectedResultExpressions.map { case attr: AttributeReference if originAttrToAliasAttr.contains(attr.canonicalized) => val alias = originAttrToAliasAttr(attr.canonicalized) @@ -227,7 +227,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case other => other } if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = resultExpressionWithAliases.map { expr => + val projectExpressions = finalResultExpressions.map { expr => // TODO At present, only push down group by attribute is supported. // In future, more attribute conversion is extended here. e.g. GetStructField expr.transform { @@ -241,7 +241,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { Project(projectExpressions, scanRelation) } else { val plan = Aggregate(output.take(groupingExpressions.length), - resultExpressionWithAliases, scanRelation) + finalResultExpressions, scanRelation) // scalastyle:off // Change the optimized logical plan to reflect the pushed down aggregate From 4dead89354ed7e4e290c31bf1df0b61df44fa25a Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sun, 13 Mar 2022 09:34:03 +0800 Subject: [PATCH 03/13] Update code --- .../datasources/FileSourceAggregatePushDownSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index 47740c5274616..5c1fe9a6b98de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite } } - test("aggregate over alias not push down") { + test("aggregate over alias push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) withDataSourceTable(data, "t") { @@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: []" // aggregate alias not pushed down + "PushedAggregation: [MIN(_1)]" // aggregate alias not pushed down checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(-2))) From a7c5504504e8902ed9795598fccac7a79768c455 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sun, 13 Mar 2022 09:39:45 +0800 Subject: [PATCH 04/13] Update code --- .../spark/sql/connector/util/V2ExpressionSQLBuilder.java | 4 ---- .../apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala | 5 +---- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index d69ebb1d8b9b9..0af0d88b0f622 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -72,8 +72,6 @@ public String build(Expression expr) { return visitNot(build(e.children()[0])); case "~": return visitUnaryArithmetic(name, build(e.children()[0])); - case "AS": - return visitAs(build(e.children()[0]), build(e.children()[1])); case "CASE_WHEN": List children = new ArrayList<>(); for (Expression child : e.children()) { @@ -127,8 +125,6 @@ protected String visitNot(String v) { protected String visitUnaryArithmetic(String name, String v) { return name +" (" + v + ")"; } - protected String visitAs(String v, String name) { return v +" AS " + name; } - protected String visitCaseWhen(String[] children) { StringBuilder sb = new StringBuilder("CASE"); for (int i = 0; i < children.length; i += 2) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index eece110be915b..1e361695056a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Add, Alias, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} -import org.apache.spark.sql.types.StringType /** * The builder to generate V2 expressions from catalyst expressions. @@ -69,8 +68,6 @@ class V2ExpressionBuilder(e: Expression) { .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) case BitwiseNot(child) => generateExpression(child) .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) - case Alias(col, name) => generateExpression(col).map(c => - new GeneralScalarExpression("AS", Array[V2Expression](c, LiteralValue(name, StringType)))) case CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression) val values = branches.map(_._2).flatMap(generateExpression) From 93e4fc4f23c0eca22f67e773bf8e07d6eafcb0ec Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sun, 13 Mar 2022 09:41:12 +0800 Subject: [PATCH 05/13] Update code --- .../datasources/FileSourceAggregatePushDownSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index 5c1fe9a6b98de..26dfe1a50971f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [MIN(_1)]" // aggregate alias not pushed down + "PushedAggregation: [MIN(_1)]" checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(-2))) From 1a764d26c97f4219323df04e9cefd5d679decec4 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sun, 13 Mar 2022 09:42:21 +0800 Subject: [PATCH 06/13] Update code --- .../execution/datasources/v2/V2ScanRelationPushDown.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 599d877e3abbe..b4fbf9fd46163 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -92,9 +92,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && - project.forall(p => p.isInstanceOf[AttributeReference] || p.isInstanceOf[Alias]) => + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) if filters.isEmpty && + project.forall(p => p.isInstanceOf[AttributeReference] || p.isInstanceOf[Alias]) => sHolder.builder match { case r: SupportsPushDownAggregates => val aliasAttrToOriginAttr = mutable.HashMap.empty[Expression, AttributeReference] From 9ebcd8e3ee53d81f053661ceccda5cee384e99bf Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 14 Mar 2022 12:49:20 +0800 Subject: [PATCH 07/13] Update code --- .../v2/V2ScanRelationPushDown.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index b4fbf9fd46163..5d4229a0cd7b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -99,18 +99,17 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val aliasAttrToOriginAttr = mutable.HashMap.empty[Expression, AttributeReference] val originAttrToAliasAttr = mutable.HashMap.empty[Expression, Attribute] collectAliases(project, aliasAttrToOriginAttr, originAttrToAliasAttr) - val newResultExpressions = resultExpressions.map { expr => - expr.transform { - case r: AttributeReference if aliasAttrToOriginAttr.contains(r.canonicalized) => - aliasAttrToOriginAttr(r.canonicalized) - } - }.asInstanceOf[Seq[NamedExpression]] - val newGroupingExpressions = groupingExpressions.map { expr => - expr.transform { - case r: AttributeReference if aliasAttrToOriginAttr.contains(r.canonicalized) => - aliasAttrToOriginAttr(r.canonicalized) - } + def replaceAliasWithAttr(expressions: Seq[Expression]): Seq[NamedExpression] = { + expressions.map { expr => + expr.transform { + case r: AttributeReference if aliasAttrToOriginAttr.contains(r.canonicalized) => + aliasAttrToOriginAttr(r.canonicalized) + } + }.asInstanceOf[Seq[NamedExpression]] } + + val newResultExpressions = replaceAliasWithAttr(resultExpressions) + val newGroupingExpressions = replaceAliasWithAttr(groupingExpressions) val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] val aggregates = collectAggregates(newResultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( From 864d52d4161fdd097b4fff6ab015e3a84216cb00 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 14 Mar 2022 12:55:15 +0800 Subject: [PATCH 08/13] Update code --- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 148 +++++++++--------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index b8bfc56538469..33fe25ef9e38f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -263,9 +263,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Row("mary", 2)) @@ -410,11 +410,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } @@ -432,11 +432,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(ID), AVG(ID)], " + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + "PushedGroupByColumns: []" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(2, 1.5))) } @@ -463,9 +463,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(12001))) } @@ -475,9 +475,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(*)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(5))) } @@ -487,9 +487,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(5))) } @@ -499,9 +499,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(DISTINCT DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(3))) } @@ -523,9 +523,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(53000))) } @@ -535,9 +535,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(DISTINCT SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(31000))) } @@ -547,11 +547,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } @@ -561,11 +561,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(DISTINCT SALARY)], " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } @@ -577,11 +577,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT, NAME]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), Row(10000, 1000), Row(12000, 1200))) @@ -597,11 +597,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df1) df1.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT, NAME]" - checkKeywordsExistsInExplain(df1, expected_plan_fragment) + checkKeywordsExistsInExplain(df1, expectedPlanFragment) } checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), Row("2#david", 10000), Row("6#jen", 12000))) @@ -615,11 +615,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT, NAME]" - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + checkKeywordsExistsInExplain(df2, expectedPlanFragment) } checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) @@ -633,9 +633,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df3, false) df3.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " - checkKeywordsExistsInExplain(df3, expected_plan_fragment) + checkKeywordsExistsInExplain(df3, expectedPlanFragment) } checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) @@ -651,11 +651,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) } @@ -667,11 +667,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MIN(SALARY)], " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) } @@ -691,11 +691,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(query, expected_plan_fragment) + checkKeywordsExistsInExplain(query, expectedPlanFragment) } checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) } @@ -707,9 +707,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(query) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), SUM(BONUS)]" - checkKeywordsExistsInExplain(query, expected_plan_fragment) + checkKeywordsExistsInExplain(query, expectedPlanFragment) } checkAnswer(query, Seq(Row(47100.0))) } @@ -734,11 +734,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -750,11 +750,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) } @@ -766,11 +766,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -782,11 +782,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [CORR(BONUS, BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } @@ -798,9 +798,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + checkKeywordsExistsInExplain(df2, expectedPlanFragment) relation.scan match { case v1: V1ScanWrapper => assert(v1.pushedDownOperators.aggregation.nonEmpty) @@ -851,12 +851,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(CASE WHEN ((SALARY) > (8000.00)) AND ((SALARY) < (10000.00))" + " THEN SALARY ELSE 0.00 END), C..., " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), Row(2, 2, 2, 2, 2, 0d, 10000d, 0d, 10000d, 10000d, 0d, 0d, 2, 0d), @@ -868,7 +868,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") checkAggregateRemoved(df, ansiMode) - val expected_plan_fragment = if (ansiMode) { + val expectedPlanFragment = if (ansiMode) { "PushedAggregates: [SUM((2147483647) + (DEPT))], " + "PushedFilters: [], PushedGroupByColumns: []" } else { @@ -876,7 +876,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } if (ansiMode) { val e = intercept[SparkException] { @@ -898,9 +898,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(query, false) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedFilters: []" - checkKeywordsExistsInExplain(query, expected_plan_fragment) + checkKeywordsExistsInExplain(query, expectedPlanFragment) } checkAnswer(query, Seq(Row(47100.0))) } @@ -939,9 +939,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(`dept id`)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(2))) } @@ -953,9 +953,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(`名`)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(2))) // scalastyle:on @@ -972,9 +972,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) @@ -989,9 +989,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df2, Seq( Row("alex", 12000.00, 12000.000000, 1), @@ -1012,9 +1012,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df, false) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) @@ -1029,9 +1029,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df, false) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df2, Seq( Row("alex", 12000.00, 12000.000000, 1), @@ -1058,9 +1058,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) @@ -1072,9 +1072,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + checkKeywordsExistsInExplain(df2, expectedPlanFragment) } checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) } @@ -1093,9 +1093,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df, false) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) @@ -1113,9 +1113,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df2, false) df2.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + checkKeywordsExistsInExplain(df2, expectedPlanFragment) } checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) From 2be235c1f077c82649d0e6263caf58af07b4dfd1 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 15 Mar 2022 16:28:34 +0800 Subject: [PATCH 09/13] Update code --- .../catalyst/expressions/AliasHelper.scala | 19 ++++++++ .../v2/V2ScanRelationPushDown.scala | 43 +++++-------------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index dea7ea0f144bf..844f2bf6c49cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -50,6 +50,14 @@ trait AliasHelper { AttributeMap(exprs.collect { case a: Alias => (a.toAttribute, a) }) } + protected def getAttrToAliasMap(aliasMap: AttributeMap[Alias]): AttributeMap[Alias] = { + val attrToAliasMap = aliasMap.values.toSeq.collect { + case alias @ Alias(originAttr: Attribute, _) => + (originAttr, alias) + } + AttributeMap(attrToAliasMap) + } + /** * Replace all attributes, that reference an alias, with the aliased expression */ @@ -77,6 +85,17 @@ trait AliasHelper { }).asInstanceOf[NamedExpression] } + /** + * Replace all alias, with the aliased attribute. + */ + protected def replaceAliasWithAttr( + expr: NamedExpression, + aliasMap: AttributeMap[Alias]): NamedExpression = { + replaceAliasButKeepName(expr, aliasMap).transform { + case Alias(attr: Attribute, _) => attr + }.asInstanceOf[NamedExpression] + } + protected def trimAliases(e: Expression): Expression = { e.transformDown { case Alias(child, _) => child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 5d4229a0cd7b6..a6846dbfee7ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} @@ -33,7 +32,7 @@ import org.apache.spark.sql.sources import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} import org.apache.spark.sql.util.SchemaUtils._ -object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { @@ -96,20 +95,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { project.forall(p => p.isInstanceOf[AttributeReference] || p.isInstanceOf[Alias]) => sHolder.builder match { case r: SupportsPushDownAggregates => - val aliasAttrToOriginAttr = mutable.HashMap.empty[Expression, AttributeReference] - val originAttrToAliasAttr = mutable.HashMap.empty[Expression, Attribute] - collectAliases(project, aliasAttrToOriginAttr, originAttrToAliasAttr) - def replaceAliasWithAttr(expressions: Seq[Expression]): Seq[NamedExpression] = { - expressions.map { expr => - expr.transform { - case r: AttributeReference if aliasAttrToOriginAttr.contains(r.canonicalized) => - aliasAttrToOriginAttr(r.canonicalized) - } - }.asInstanceOf[Seq[NamedExpression]] - } - - val newResultExpressions = replaceAliasWithAttr(resultExpressions) - val newGroupingExpressions = replaceAliasWithAttr(groupingExpressions) + val aliasMap = getAliasMap(project) + val newResultExpressions = + resultExpressions.map(replaceAliasWithAttr(_, aliasMap)) + val newGroupingExpressions = groupingExpressions.asInstanceOf[Seq[NamedExpression]] + .map(replaceAliasWithAttr(_, aliasMap)) val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] val aggregates = collectAggregates(newResultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( @@ -217,11 +207,10 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) + val aliasAttrMap = getAttrToAliasMap(aliasMap) val finalResultExpressions = selectedResultExpressions.map { - case attr: AttributeReference - if originAttrToAliasAttr.contains(attr.canonicalized) => - val alias = originAttrToAliasAttr(attr.canonicalized) - Alias(attr, alias.name)(alias.exprId) + case attr: AttributeReference => + aliasAttrMap.getOrElse(attr, attr) case other => other } if (r.supportCompletePushDown(pushedAggregates.get)) { @@ -303,18 +292,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } - private def collectAliases(project: Seq[NamedExpression], - aliasAttrToOriginAttr: mutable.HashMap[Expression, AttributeReference], - originAttrToAliasAttr: mutable.HashMap[Expression, Attribute]) = { - project.collect { - case alias @ Alias(attr: AttributeReference, _) => - val output = alias.toAttribute - aliasAttrToOriginAttr(output.canonicalized) = attr - originAttrToAliasAttr(attr.canonicalized) = output - case other => other - } - } - private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. From 08a1b8ec3670a8a5758e36893f487f25f5d8c331 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 15 Mar 2022 17:13:27 +0800 Subject: [PATCH 10/13] Update code --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index a6846dbfee7ac..8656e06e7970e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -96,8 +96,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit sHolder.builder match { case r: SupportsPushDownAggregates => val aliasMap = getAliasMap(project) - val newResultExpressions = - resultExpressions.map(replaceAliasWithAttr(_, aliasMap)) + val newResultExpressions = resultExpressions.map(replaceAliasWithAttr(_, aliasMap)) val newGroupingExpressions = groupingExpressions.asInstanceOf[Seq[NamedExpression]] .map(replaceAliasWithAttr(_, aliasMap)) val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] From 374e682d09551f63b836c54f96ac6b2eb03ef231 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 16 Mar 2022 10:13:41 +0800 Subject: [PATCH 11/13] Update code --- .../datasources/v2/V2ScanRelationPushDown.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 8656e06e7970e..4e4b0d8b31421 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -91,14 +91,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) if filters.isEmpty && - project.forall(p => p.isInstanceOf[AttributeReference] || p.isInstanceOf[Alias]) => + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) + if filters.isEmpty && project.forall(_.deterministic) => sHolder.builder match { case r: SupportsPushDownAggregates => val aliasMap = getAliasMap(project) val newResultExpressions = resultExpressions.map(replaceAliasWithAttr(_, aliasMap)) - val newGroupingExpressions = groupingExpressions.asInstanceOf[Seq[NamedExpression]] - .map(replaceAliasWithAttr(_, aliasMap)) + val newGroupingExpressions = groupingExpressions.map { + case e: NamedExpression => replaceAliasWithAttr(e, aliasMap) + case other => other + } val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] val aggregates = collectAggregates(newResultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( From f0d0e7015bca8396894c03ebddf5abba574ab474 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 17 Mar 2022 13:46:36 +0800 Subject: [PATCH 12/13] Update code --- .../catalyst/expressions/AliasHelper.scala | 19 ----------- .../v2/V2ScanRelationPushDown.scala | 34 ++++++++++++++++++- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index 844f2bf6c49cc..dea7ea0f144bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -50,14 +50,6 @@ trait AliasHelper { AttributeMap(exprs.collect { case a: Alias => (a.toAttribute, a) }) } - protected def getAttrToAliasMap(aliasMap: AttributeMap[Alias]): AttributeMap[Alias] = { - val attrToAliasMap = aliasMap.values.toSeq.collect { - case alias @ Alias(originAttr: Attribute, _) => - (originAttr, alias) - } - AttributeMap(attrToAliasMap) - } - /** * Replace all attributes, that reference an alias, with the aliased expression */ @@ -85,17 +77,6 @@ trait AliasHelper { }).asInstanceOf[NamedExpression] } - /** - * Replace all alias, with the aliased attribute. - */ - protected def replaceAliasWithAttr( - expr: NamedExpression, - aliasMap: AttributeMap[Alias]): NamedExpression = { - replaceAliasButKeepName(expr, aliasMap).transform { - case Alias(attr: Attribute, _) => attr - }.asInstanceOf[NamedExpression] - } - protected def trimAliases(e: Expression): Expression = { e.transformDown { case Alias(child, _) => child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 4e4b0d8b31421..9b0ac9ad84b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeMap, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.CollapseProject.{buildCleanedProjectList, canCollapseExpressions} import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule @@ -28,6 +29,7 @@ import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} import org.apache.spark.sql.util.SchemaUtils._ @@ -40,6 +42,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit createScanBuilder, pushDownSample, pushDownFilters, + collapseProject, pushDownAggregates, pushDownLimits, pruneColumns) @@ -87,6 +90,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) } + private def collapseProject(plan: LogicalPlan): LogicalPlan = { + val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + plan transformUp { + case agg @ Aggregate(_, aggregateExpressions, p: Project) + if canCollapseExpressions(aggregateExpressions, p.projectList, alwaysInline) => + agg.copy(aggregateExpressions = buildCleanedProjectList( + aggregateExpressions, p.projectList)) + } + } + def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => @@ -276,6 +289,25 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } } + /** + * Replace all alias, with the aliased attribute. + */ + private def replaceAliasWithAttr( + expr: NamedExpression, + aliasMap: AttributeMap[Alias]): NamedExpression = { + replaceAliasButKeepName(expr, aliasMap).transform { + case Alias(attr: Attribute, _) => attr + }.asInstanceOf[NamedExpression] + } + + protected def getAttrToAliasMap(aliasMap: AttributeMap[Alias]): AttributeMap[Alias] = { + val attrToAliasMap = aliasMap.values.toSeq.collect { + case alias @ Alias(originAttr: Attribute, _) => + (originAttr, alias) + } + AttributeMap(attrToAliasMap) + } + private def collectAggregates(resultExpressions: Seq[NamedExpression], aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { var ordinal = 0 From 7a145e8d9ac917a5e5553605e0374c6a88461f11 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 17 Mar 2022 15:24:18 +0800 Subject: [PATCH 13/13] Update code --- .../datasources/v2/V2ScanRelationPushDown.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 9b0ac9ad84b94..7542599793e19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -42,7 +42,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit createScanBuilder, pushDownSample, pushDownFilters, - collapseProject, pushDownAggregates, pushDownLimits, pruneColumns) @@ -100,7 +99,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } } - def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform { + def pushDownAggregates(plan: LogicalPlan): LogicalPlan = collapseProject(plan).transform { // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { @@ -175,14 +174,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } if (selectedTranslatedAggregates.isEmpty) { - aggNode // return original plan node + return plan // return original plan node } else if (!r.supportCompletePushDown(selectedTranslatedAggregates.get) && !supportPartialAggPushDown(selectedTranslatedAggregates.get)) { - aggNode // return original plan node + return plan // return original plan node } else { val pushedAggregates = selectedTranslatedAggregates.filter(r.pushAggregation) if (pushedAggregates.isEmpty) { - aggNode // return original plan node + return plan // return original plan node } else { // No need to do column pruning because only the aggregate columns are used as // DataSourceV2ScanRelation output columns. All the other columns are not @@ -283,9 +282,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } } } - case _ => aggNode + case _ => return plan } - case _ => aggNode + case _ => return plan } }