From 9ca9bb8ba4a5830d829ab637273b3fc9d359ad85 Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Wed, 28 Oct 2020 05:51:47 +0000 Subject: [PATCH 1/2] [SPARK-33183][SQL] Fix Optimizer rule EliminateSorts and add a physical rule to remove redundant sorts This PR aims to fix a correctness bug in the optimizer rule `EliminateSorts`. It also adds a new physical rule to remove redundant sorts that cannot be eliminated in the Optimizer rule after the bugfix. A global sort should not be eliminated even if its child is ordered since we don't know if its child ordering is global or local. For example, in the following scenario, the first sort shouldn't be removed because it has a stronger guarantee than the second sort even if the sort orders are the same for both sorts. ``` Sort(orders, global = True, ...) Sort(orders, global = False, ...) ``` Since there is no straightforward way to identify whether a node's output ordering is local or global, we should not remove a global sort even if its child is already ordered. Does this PR introduce any user-facing change? Yes. How was this patch tested? Unit tests. Closes #30093 from allisonwang-db/fix-sort. 
Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Signed-off-by: Wenchen Fan (cherry picked from commit 9fb45361fd00b046e04748e1a1c8add3fa09f01c) Signed-off-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> --- .../sql/catalyst/optimizer/Optimizer.scala | 16 +- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../optimizer/EliminateSortsSuite.scala | 102 +++++++++++-- .../spark/sql/execution/QueryExecution.scala | 1 + .../sql/execution/RemoveRedundantSorts.scala | 46 ++++++ .../adaptive/AdaptiveSparkPlanExec.scala | 2 + .../spark/sql/execution/PlannerSuite.scala | 13 -- .../execution/RemoveRedundantSortsSuite.scala | 144 ++++++++++++++++++ 8 files changed, 303 insertions(+), 28 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d93c4a5bc459a..4629cbbfa7f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -971,7 +971,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { /** * Removes Sort operation. 
This can happen: * 1) if the sort order is empty or the sort order does not have any reference - * 2) if the child is already sorted + * 2) if the Sort operator is a local sort and the child is already sorted * 3) if there is another Sort operator separated by 0...n Project/Filter operators * 4) if the Sort operator is within Join separated by 0...n Project/Filter operators only, * and the Join conditions is deterministic @@ -979,12 +979,18 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { * and the aggregate function is order irrelevant */ object EliminateSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally + + private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) - if (newOrders.isEmpty) child else s.copy(order = newOrders) - case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => - child + if (newOrders.isEmpty) { + applyLocally.lift(child).getOrElse(child) + } else { + s.copy(order = newOrders) + } + case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + applyLocally.lift(child).getOrElse(child) case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c74888340ac12..be011903f50dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1201,6 +1201,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts") + .internal() + .doc("Whether to remove redundant physical sort node") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + val STATE_STORE_PROVIDER_CLASS = buildConf("spark.sql.streaming.stateStore.providerClass") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index e2b599a7c090c..e34f141412ca6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -97,12 +97,34 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("remove redundant order by") { + test("SPARK-33183: remove consecutive no-op sorts") { + val plan = testRelation.orderBy().orderBy().orderBy() + val optimized = Optimize.execute(plan.analyze) + val correctAnswer = testRelation.analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: remove redundant sort by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) val correctAnswer = orderedPlan.limit(2).select('a).analyze - comparePlans(Optimize.execute(optimized), correctAnswer) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: remove all redundant local sorts") { + val orderedPlan = 
testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = testRelation.orderBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: should not remove global sort") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(reordered.analyze) + val correctAnswer = reordered.analyze + comparePlans(optimized, correctAnswer) } test("do not remove sort if the order is different") { @@ -113,22 +135,39 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("filters don't affect order") { + test("SPARK-33183: remove top level local sort with filter operators") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc) val optimized = Optimize.execute(filteredAndReordered.analyze) val correctAnswer = orderedPlan.where('a > Literal(10)).analyze comparePlans(optimized, correctAnswer) } - test("limits don't affect order") { + test("SPARK-33183: keep top level global sort with filter operators") { + val projectPlan = testRelation.select('a, 'b) + val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: limits should not affect order for local sort") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = 
orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc) val optimized = Optimize.execute(filteredAndReordered.analyze) val correctAnswer = orderedPlan.limit(Literal(10)).analyze comparePlans(optimized, correctAnswer) } + test("SPARK-33183: should not remove global sort with limit operators") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = filteredAndReordered.analyze + comparePlans(optimized, correctAnswer) + } + test("different sorts are not simplified if limit is in between") { val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) .orderBy('a.asc) @@ -137,11 +176,11 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("range is already sorted") { + test("SPARK-33183: should not remove global sort with range operator") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) val optimized = Optimize.execute(orderedPlan.analyze) - val correctAnswer = inputPlan.analyze + val correctAnswer = orderedPlan.analyze comparePlans(optimized, correctAnswer) val reversedPlan = inputPlan.orderBy('id.desc) @@ -152,10 +191,18 @@ class EliminateSortsSuite extends PlanTest { val negativeStepInputPlan = Range(10L, 1L, -1, 10) val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) - val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) } + test("SPARK-33183: remove local sort with range operator") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = 
inputPlan.sortBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("sort should not be removed when there is a node which doesn't guarantee any order") { val orderedPlan = testRelation.select('a, 'b) val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) @@ -319,4 +366,39 @@ class EliminateSortsSuite extends PlanTest { val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze) comparePlans(optimized, correctAnswer) } + + test("SPARK-33183: remove consecutive global sorts with the same ordering") { + Seq( + (testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)), + (testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc)) + ).foreach { case (ordered, answer) => + val optimized = Optimize.execute(ordered.analyze) + comparePlans(optimized, answer.analyze) + } + } + + test("SPARK-33183: remove consecutive local sorts with the same ordering") { + val orderedPlan = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = testRelation.sortBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: remove consecutive local sorts with different ordering") { + val orderedPlan = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = testRelation.sortBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") { + val correctAnswer = testRelation.orderBy('a.asc).analyze + Seq( + testRelation.sortBy('a.asc).orderBy('a.asc), + testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc) + ).foreach { ordered => + val optimized = Optimize.execute(ordered.analyze) + comparePlans(optimized, correctAnswer) + } + } } 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 50c9c31029d90..574a67fecf936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -297,6 +297,7 @@ object QueryExecution { Seq( PlanDynamicPruningFilters(sparkSession), PlanSubqueries(sparkSession), + RemoveRedundantSorts(sparkSession.sessionState.conf), EnsureRequirements(sparkSession.sessionState.conf), ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf, sparkSession.sessionState.columnarRules), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala new file mode 100644 index 0000000000000..71f36c8c1dd5a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +/** + * Remove redundant SortExec node from the spark plan. A sort node is redundant when + * its child satisfies both its sort orders and its required child distribution. Note + * this rule differs from the Optimizer rule EliminateSorts in that this rule also checks + * if the child satisfies the required distribution so that it is safe to remove not only a + * local sort but also a global sort when its child already satisfies required sort orders. + */ +case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = { + if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) { + plan + } else { + removeSorts(plan) + } + } + + private def removeSorts(plan: SparkPlan): SparkPlan = plan transform { + case s @ SortExec(orders, _, child, _) + if SortOrder.orderingSatisfies(child.outputOrdering, orders) && + child.outputPartitioning.satisfies(s.requiredChildDistribution.head) => + child + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 8b59b1236782e..4e73f06f36f1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -83,12 +83,14 @@ case class AdaptiveSparkPlanExec( ) } + @transient private val removeRedundantSorts = RemoveRedundantSorts(conf) @transient private val ensureRequirements = EnsureRequirements(conf) // A list of physical plan rules to be applied before creation of query stages. 
The physical // plan should reach a final status of query stages (i.e., no more addition or removal of // Exchange nodes) after running these rules. private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq( + removeRedundantSorts, ensureRequirements ) ++ context.session.sessionState.queryStagePrepRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 9c8e44323ce98..b7be0f1320e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -234,19 +234,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } } - test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { - val query = testData.select('key, 'value).sort('key.desc).cache() - assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) - val resorted = query.sort('key.desc) - assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) - assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == - (1 to 100).reverse) - // with a different order, the sort is needed - val sortedAsc = query.sort('key) - assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) - assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) - } - test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..54c5a33441900 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + + +abstract class RemoveRedundantSortsSuiteBase + extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + import testImplicits._ + + private def checkNumSorts(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count) + } + + private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = { + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { + val df = sql(query) + checkNumSorts(df, enabledCount) + val result = df.collect() + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") { + val df = sql(query) + checkNumSorts(df, disabledCount) + checkAnswer(df, result) + } + } + } + + test("remove redundant sorts with limit") { + withTempView("t") { + spark.range(100).select('id as 
"key").createOrReplaceTempView("t") + val query = + """ + |SELECT key FROM + | (SELECT key FROM t WHERE key > 10 ORDER BY key DESC LIMIT 10) + |ORDER BY key DESC + |""".stripMargin + checkSorts(query, 0, 1) + } + } + + test("remove redundant sorts with broadcast hash join") { + withTempView("t1", "t2") { + spark.range(1000).select('id as "key").createOrReplaceTempView("t1") + spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + + val queryTemplate = """ + |SELECT /*+ BROADCAST(%s) */ t1.key FROM + | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 + |JOIN + | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2 + |ON t1.key = t2.key + |ORDER BY %s + """.stripMargin + + // No sort should be removed since the stream side (t2) order DESC + // does not satisfy the required sort order ASC. + val buildLeftOrderByRightAsc = queryTemplate.format("t1", "t2.key ASC") + checkSorts(buildLeftOrderByRightAsc, 1, 1) + + // The top sort node should be removed since the stream side (t2) order DESC already + // satisfies the required sort order DESC. + val buildLeftOrderByRightDesc = queryTemplate.format("t1", "t2.key DESC") + checkSorts(buildLeftOrderByRightDesc, 0, 1) + + // No sort should be removed since the sort ordering from broadcast-hash join is based + // on the stream side (t2) and the required sort order is from t1. + val buildLeftOrderByLeftDesc = queryTemplate.format("t1", "t1.key DESC") + checkSorts(buildLeftOrderByLeftDesc, 1, 1) + + // The top sort node should be removed since the stream side (t1) order DESC already + // satisfies the required sort order DESC. 
+ val buildRightOrderByLeftDesc = queryTemplate.format("t2", "t1.key DESC") + checkSorts(buildRightOrderByLeftDesc, 0, 1) + } + } + + test("remove redundant sorts with sort merge join") { + withTempView("t1", "t2") { + spark.range(1000).select('id as "key").createOrReplaceTempView("t1") + spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + val query = """ + |SELECT /*+ MERGE(t1) */ t1.key FROM + | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 + |JOIN + | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2 + |ON t1.key = t2.key + |ORDER BY t1.key + """.stripMargin + + val queryAsc = query + " ASC" + checkSorts(queryAsc, 2, 3) + + // The top level sort should not be removed since the child output ordering is ASC and + // the required ordering is DESC. + val queryDesc = query + " DESC" + checkSorts(queryDesc, 3, 3) + } + } + + test("cached sorted data doesn't need to be re-sorted") { + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { + val df = spark.range(1000).select('id as "key").sort('key.desc).cache() + val resorted = df.sort('key.desc) + val sortedAsc = df.sort('key.asc) + checkNumSorts(df, 0) + checkNumSorts(resorted, 0) + checkNumSorts(sortedAsc, 1) + val result = resorted.collect() + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") { + val resorted = df.sort('key.desc) + checkNumSorts(resorted, 1) + checkAnswer(resorted, result) + } + } + } +} + +class RemoveRedundantSortsSuite extends RemoveRedundantSortsSuiteBase + with DisableAdaptiveExecutionSuite + +class RemoveRedundantSortsSuiteAE extends RemoveRedundantSortsSuiteBase + with EnableAdaptiveExecutionSuite From 9526deea2f24208dbd6ebd0ed29e8ddaadd84604 Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Thu, 29 Oct 2020 19:49:13 -0700 Subject: [PATCH 2/2] fix test --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../execution/RemoveRedundantSortsSuite.scala | 
36 ------------------- 2 files changed, 1 insertion(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index be011903f50dc..23d1d70bdb22c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1204,7 +1204,7 @@ object SQLConf { val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts") .internal() .doc("Whether to remove redundant physical sort node") - .version("3.1.0") + .version("2.4.8") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala index 54c5a33441900..1978d22cb87bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala @@ -60,42 +60,6 @@ abstract class RemoveRedundantSortsSuiteBase } } - test("remove redundant sorts with broadcast hash join") { - withTempView("t1", "t2") { - spark.range(1000).select('id as "key").createOrReplaceTempView("t1") - spark.range(1000).select('id as "key").createOrReplaceTempView("t2") - - val queryTemplate = """ - |SELECT /*+ BROADCAST(%s) */ t1.key FROM - | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 - |JOIN - | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2 - |ON t1.key = t2.key - |ORDER BY %s - """.stripMargin - - // No sort should be removed since the stream side (t2) order DESC - // does not satisfy the required sort order ASC. 
- val buildLeftOrderByRightAsc = queryTemplate.format("t1", "t2.key ASC") - checkSorts(buildLeftOrderByRightAsc, 1, 1) - - // The top sort node should be removed since the stream side (t2) order DESC already - // satisfies the required sort order DESC. - val buildLeftOrderByRightDesc = queryTemplate.format("t1", "t2.key DESC") - checkSorts(buildLeftOrderByRightDesc, 0, 1) - - // No sort should be removed since the sort ordering from broadcast-hash join is based - // on the stream side (t2) and the required sort order is from t1. - val buildLeftOrderByLeftDesc = queryTemplate.format("t1", "t1.key DESC") - checkSorts(buildLeftOrderByLeftDesc, 1, 1) - - // The top sort node should be removed since the stream side (t1) order DESC already - // satisfies the required sort order DESC. - val buildRightOrderByLeftDesc = queryTemplate.format("t2", "t1.key DESC") - checkSorts(buildRightOrderByLeftDesc, 0, 1) - } - } - test("remove redundant sorts with sort merge join") { withTempView("t1", "t2") { spark.range(1000).select('id as "key").createOrReplaceTempView("t1")