From b932d2f3a6741a8ef052cbd8087f4b0836c617d6 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Fri, 11 Aug 2017 21:00:00 +0800 Subject: [PATCH 1/5] spark-19471 --- .../aggregate/AggregationIterator.scala | 4 ++++ .../aggregate/HashAggregateExec.scala | 3 ++- .../aggregate/ObjectAggregationIterator.scala | 2 ++ .../aggregate/ObjectHashAggregateExec.scala | 3 ++- .../aggregate/SortAggregateExec.scala | 3 ++- .../SortBasedAggregationIterator.scala | 2 ++ .../TungstenAggregationIterator.scala | 2 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 22 +++++++++++++++++++ 8 files changed, 38 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 7c11fdb9792e8..28d2055aef22d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ * is used to generate result. */ abstract class AggregationIterator( + partIndex: Int, groupingExpressions: Seq[NamedExpression], inputAttributes: Seq[Attribute], aggregateExpressions: Seq[AggregateExpression], @@ -229,6 +230,7 @@ abstract class AggregationIterator( allImperativeAggregateFunctions(i).eval(currentBuffer)) i += 1 } + resultProjection.initialize(partIndex) resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { @@ -251,12 +253,14 @@ abstract class AggregationIterator( typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer) i += 1 } + resultProjection.initialize(partIndex) resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { // Grouping-only: we only output values based on grouping expressions. 
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { + resultProjection.initialize(partIndex) resultProjection(currentGroupingKey) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 56f61c30c4a38..80ea458687865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -96,7 +96,7 @@ case class HashAggregateExec( val spillSize = longMetric("spillSize") val avgHashProbe = longMetric("avgHashProbe") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsWithIndex { (partIndex, iter) => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { @@ -106,6 +106,7 @@ case class HashAggregateExec( } else { val aggregationIterator = new TungstenAggregationIterator( + partIndex, groupingExpressions, aggregateExpressions, aggregateAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index eef2c4e843f35..c68dbc73f0447 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter class ObjectAggregationIterator( + partIndex: Int, outputAttributes: Seq[Attribute], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -43,6 +44,7 @@ class ObjectAggregationIterator( 
fallbackCountThreshold: Int, numOutputRows: SQLMetric) extends AggregationIterator( + partIndex, groupingExpressions, originalInputAttributes, aggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index b53521b1b6ba2..6316e06a8f34e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -98,7 +98,7 @@ case class ObjectHashAggregateExec( val numOutputRows = longMetric("numOutputRows") val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input kvIterator is empty, @@ -107,6 +107,7 @@ case class ObjectHashAggregateExec( } else { val aggregationIterator = new ObjectAggregationIterator( + partIndex, child.output, groupingExpressions, aggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index be3198b8e7d82..a43235790834e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -74,7 +74,7 @@ case class SortAggregateExec( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => // Because the 
constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. val hasInput = iter.hasNext @@ -84,6 +84,7 @@ case class SortAggregateExec( Iterator[UnsafeRow]() } else { val outputIter = new SortBasedAggregationIterator( + partIndex, groupingExpressions, child.output, iter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index a5a444b160c63..492b0f2da77cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( + partIndex: Int, groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], @@ -37,6 +38,7 @@ class SortBasedAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, numOutputRows: SQLMetric) extends AggregationIterator( + partIndex, groupingExpressions, valueAttributes, aggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index cfa930607360c..c6bf3fd6ab83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -77,6 +77,7 @@ import org.apache.spark.unsafe.KVIterator * the iterator containing input [[UnsafeRow]]s. 
*/ class TungstenAggregationIterator( + partIndex: Int, groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], @@ -91,6 +92,7 @@ class TungstenAggregationIterator( spillSize: SQLMetric, avgHashProbe: SQLMetric) extends AggregationIterator( + partIndex, groupingExpressions, originalInputAttributes, aggregateExpressions, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0681b9cbeb1d8..d5e8963199e1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -449,6 +449,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + private def assertNoExceptions(c: Column): Unit = { + for ((wholeStage, useObjectHashAgg) <- Seq((true, false), (false, false), (false, true))) { + withSQLConf( + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), + (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { + val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") + // HashAggregate + df.groupBy("x").agg(c, sum("y")).collect() + // ObjectHashAggregate and SortAggregate + df.groupBy("x").agg(c, collect_list("y")).collect() + } + } + } + + test("SPARK-19471: AggregationIterator does not initialize the generated result projection" + + " before using it") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertNoExceptions(_)) + } + test("SPARK-21281 use string types by default if array and map have no argument") { val ds = spark.range(1) var expectedSchema = new StructType() From bb29b8f7e8e8be436ea028acfbe16ae6f4977169 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Sun, 13 Aug 2017 20:46:05 
+0800 Subject: [PATCH 2/5] add comment for param --- .../sql/execution/aggregate/TungstenAggregationIterator.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index c6bf3fd6ab83e..756eeb642e2d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -60,6 +60,8 @@ import org.apache.spark.unsafe.KVIterator * - Part 8: A utility function used to generate a result when there is no * input and there is no grouping expression. * + * @param partIndex + * index of the partition * @param groupingExpressions * expressions for grouping keys * @param aggregateExpressions From 5239ebb5843315430d5c942dc53e09fb09d6c1c8 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Mon, 14 Aug 2017 08:32:21 +0800 Subject: [PATCH 3/5] check the plan for unit test --- .../spark/sql/DataFrameFunctionsSuite.scala | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d5e8963199e1d..9ad83a2ef4689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -24,6 +24,8 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, 
SortAggregateExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -450,15 +452,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } private def assertNoExceptions(c: Column): Unit = { - for ((wholeStage, useObjectHashAgg) <- Seq((true, false), (false, false), (false, true))) { + for ((wholeStage, useObjectHashAgg) <- + Seq((true, true), (true, false), (false, true), (false, false))) { withSQLConf( (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { + val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") - // HashAggregate - df.groupBy("x").agg(c, sum("y")).collect() - // ObjectHashAggregate and SortAggregate - df.groupBy("x").agg(c, collect_list("y")).collect() + + // HashAggregate test case + val hashAggDF = df.groupBy("x").agg(c, sum("y")) + val hashAggPlan = hashAggDF.queryExecution.executedPlan + if (wholeStage) { + assert(hashAggPlan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child + .isInstanceOf[HashAggregateExec]).isDefined) + } else { + assert(hashAggPlan.isInstanceOf[HashAggregateExec]) + } + hashAggDF.collect() + + // ObjectHashAggregate and SortAggregate test cases + val objHashOrSort_AggDF = df.groupBy("x").agg(c, collect_list("y")) + val objHashOrSort_Plan = objHashOrSort_AggDF.queryExecution.executedPlan + if (useObjectHashAgg) { + assert(objHashOrSort_Plan.isInstanceOf[ObjectHashAggregateExec]) + } else { + assert(objHashOrSort_Plan.isInstanceOf[SortAggregateExec]) + } + objHashOrSort_AggDF.collect() } } } From f55b161d653567bcdda96e5337ee4fab1dca0165 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Mon, 14 Aug 2017 14:59:57 +0800 Subject: [PATCH 4/5] code refactor --- .../aggregate/AggregationIterator.scala | 6 +++--- .../spark/sql/DataFrameFunctionsSuite.scala | 20 +++++++++---------- 2 files
changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 28d2055aef22d..98c4a51299958 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -218,6 +218,7 @@ abstract class AggregationIterator( val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + resultProjection.initialize(partIndex) (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { // Generate results for all expression-based aggregate functions. @@ -230,13 +231,13 @@ abstract class AggregationIterator( allImperativeAggregateFunctions(i).eval(currentBuffer)) i += 1 } - resultProjection.initialize(partIndex) resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { val resultProjection = UnsafeProjection.create( groupingAttributes ++ bufferAttributes, groupingAttributes ++ bufferAttributes) + resultProjection.initialize(partIndex) // TypedImperativeAggregate stores generic object in aggregation buffer, and requires // calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info. @@ -253,14 +254,13 @@ abstract class AggregationIterator( typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer) i += 1 } - resultProjection.initialize(partIndex) resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { // Grouping-only: we only output values based on grouping expressions. 
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + resultProjection.initialize(partIndex) (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { - resultProjection.initialize(partIndex) resultProjection(currentGroupingKey) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9ad83a2ef4689..19f313dcbfbcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -464,24 +464,24 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val hashAggDF = df.groupBy("x").agg(c, sum("y")) val hashAggPlan = hashAggDF.queryExecution.executedPlan if (wholeStage) { - assert(hashAggPlan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child - .isInstanceOf[HashAggregateExec]).isDefined) + assert(hashAggPlan.find { + case WholeStageCodegenExec(_: HashAggregateExec) => true + case _ => false + }.isDefined) } else { assert(hashAggPlan.isInstanceOf[HashAggregateExec]) } hashAggDF.collect() // ObjectHashAggregate and SortAggregate test cases - val objHashOrSort_AggDF = df.groupBy("x").agg(c, collect_list("y")) - val objHashOrSort_Plan = objHashOrSort_AggDF.queryExecution.executedPlan + val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) + val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan if (useObjectHashAgg) { - assert(objHashOrSort_Plan.isInstanceOf[ObjectHashAggregateExec]) + assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) } else { - assert(objHashOrSort_Plan.isInstanceOf[SortAggregateExec]) + assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) } - objHashOrSort_AggDF.collect() + objHashAggOrSortAggDF.collect() } } } @@ -491,7 +491,7 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq( monotonically_increasing_id(), spark_partition_id(), rand(Random.nextLong()), randn(Random.nextLong()) - ).foreach(assertNoExceptions(_)) + ).foreach(assertNoExceptions) } test("SPARK-21281 use string types by default if array and map have no argument") { From d58ffaa434337ae19f4b1f59524c84943ff7934f Mon Sep 17 00:00:00 2001 From: donnyzone Date: Mon, 14 Aug 2017 17:50:22 +0800 Subject: [PATCH 5/5] test comment --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 19f313dcbfbcb..fdb9f1d1e0e94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -473,7 +473,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } hashAggDF.collect() - // ObjectHashAggregate and SortAggregate test cases + // ObjectHashAggregate and SortAggregate test case val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan if (useObjectHashAgg) {