From dbc38d371ea3615ac6e756ade27f6fdefafa1feb Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Thu, 19 Nov 2020 13:29:01 +0000 Subject: [PATCH] [SPARK-33472][SQL] Adjust RemoveRedundantSorts rule order This PR switched the order for the rule `RemoveRedundantSorts` and `EnsureRequirements` so that `EnsureRequirements` will be invoked before `RemoveRedundantSorts` to avoid IllegalArgumentException when instantiating PartitioningCollection. `RemoveRedundantSorts` rule uses SparkPlan's `outputPartitioning` to check whether a sort node is redundant. Currently, it is added before `EnsureRequirements`. Since `PartitioningCollection` requires left and right partitioning to have the same number of partitions, which is not necessarily true before applying `EnsureRequirements`, the rule can fail with the following exception: ``` IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions. ``` No Unit test Closes #30373 from allisonwang-db/sort-follow-up. Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Signed-off-by: Wenchen Fan (cherry picked from commit a03c540cf7fe92160caf41ef6d2e2993f667dc59) Signed-off-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> --- .../spark/sql/execution/QueryExecution.scala | 4 ++- .../spark/sql/execution/SparkPlan.scala | 7 ++++- .../execution/RemoveRedundantSortsSuite.scala | 27 +++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index b92f34680f668..0b9c46936b24f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -97,8 +97,10 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( PlanSubqueries(sparkSession), - RemoveRedundantSorts(sparkSession.sessionState.conf), EnsureRequirements(sparkSession.sessionState.conf), + // `RemoveRedundantSorts` needs to be added before `EnsureRequirements` to guarantee the same + // number of partitions when instantiating PartitioningCollection. + RemoveRedundantSorts(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 7646f9613efb3..28addf6025110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -91,7 +91,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def longMetric(name: String): SQLMetric = metrics(name) // TODO: Move to `DistributedPlan` - /** Specifies how data is partitioned across different nodes in the cluster. */ + /** + * Specifies how data is partitioned across different nodes in the cluster. + * Note this method may fail if it is invoked before `EnsureRequirements` is applied + * since `PartitioningCollection` requires all its partitionings to have + * the same number of partitions. + */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala index f7987e293b3f8..b82e5cb77c077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -99,4 +101,29 @@ class RemoveRedundantSortsSuite } } } + + test("SPARK-33472: shuffled join with different left and right side partition numbers") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTempView("t1", "t2") { + spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1") + (0 to 100).toDF("key").createOrReplaceTempView("t2") + + val query = """ + |SELECT t1.key + |FROM t1 JOIN t2 ON t1.key = t2.key + |WHERE t1.key > 10 AND t2.key < 50 + |ORDER BY t1.key ASC + """.stripMargin + + val df = sql(query) + val sparkPlan = df.queryExecution.sparkPlan + val join = sparkPlan.collect { case j: SortMergeJoinExec => j }.head + val leftPartitioning = join.left.outputPartitioning + assert(leftPartitioning.isInstanceOf[RangePartitioning]) + assert(leftPartitioning.numPartitions == 2) + assert(join.right.outputPartitioning == UnknownPartitioning(0)) + checkSorts(query, 3, 3) + } + } + } }