From 93c93da7592327cde653815bf8a8b4c9c0193931 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Nov 2016 08:45:39 +0000 Subject: [PATCH 1/3] Fix memory leak. --- .../apache/spark/sql/execution/SparkPlan.scala | 6 ++++++ .../org/apache/spark/sql/DatasetSuite.scala | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index cadab37a449aa..150609256b656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -234,6 +234,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ row.writeToStream(out, buffer) count += 1 } + // If iterator has more elements, we should consume them all. Otherwise under wholestage + // codegen, as we release resources after consuming all elements (e.g., HashAggregate), it + // will cause problems such as memory leak. + while (iter.hasNext) { + iter.next() + } out.writeInt(-1) out.flush() out.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 81fa8cbf22384..350ffb50231f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataTypes, IntegerType, StringType, StructField, StructType} class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -1051,6 +1051,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsDouble, arrayDouble) checkDataset(dsString, arrayString) } + + test("SPARK-18487: Consume all elements for show/take to avoid memory leak") { + val rng = new scala.util.Random(42) + val data = sparkContext.parallelize(Seq.tabulate(100) { i => + Row(Array.fill(10)(rng.nextInt(10))) + }) + val schema = StructType(Seq( + StructField("arr", DataTypes.createArrayType(DataTypes.IntegerType)) + )) + val df = spark.createDataFrame(data, schema) + val exploded = df.select(struct(col("*")).as("star"), explode(col("arr")).as("a")) + val joined = exploded.join(exploded, "a").drop("a").distinct() + joined.show() + } } case class Generic[T](id: T, value: Double) From 2f304f07efafb37eeae45a00ebd671fcedceb97a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 18 Nov 2016 02:06:43 +0000 Subject: [PATCH 2/3] Use task completion listener instead. --- .../apache/spark/sql/execution/SparkPlan.scala | 6 ------ .../execution/aggregate/HashAggregateExec.scala | 17 ++++++++++++++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 150609256b656..cadab37a449aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -234,12 +234,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ row.writeToStream(out, buffer) count += 1 } - // If iterator has more elements, we should consume them all. Otherwise under wholestage - // codegen, as we release resources after consuming all elements (e.g., HashAggregate), it - // will cause problems such as memory leak. - while (iter.hasNext) { - iter.next() - } out.writeInt(-1) out.flush() out.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed067e565..0eee4f31196a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -295,6 +295,11 @@ case class HashAggregateExec( private var hashMapTerm: String = _ private var sorterTerm: String = _ + // Becasue Dataset.show/take methods will end of iteraton before reaching the end of all rows, + // we may not release resources then and cause memory leak. So we need to hold the reference + // of the hash map if it is created and release the resources after task completion. + private var hashMapToRelease: UnsafeFixedWidthAggregationMap = _ + /** * This is called by generated Java class, should be public. */ @@ -302,17 +307,23 @@ case class HashAggregateExec( // create initialized aggregate buffer val initExpr = declFunctions.flatMap(f => f.initialValues) val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + val context = TaskContext.get() // create hashMap - new UnsafeFixedWidthAggregationMap( + hashMapToRelease = new UnsafeFixedWidthAggregationMap( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + context.taskMemoryManager(), 1024 * 16, // initial capacity - TaskContext.get().taskMemoryManager().pageSizeBytes, + context.taskMemoryManager().pageSizeBytes, false // disable tracking of performance metrics ) + + // Release the resources of the hash map when the end of task. + context.addTaskCompletionListener(_ => hashMapToRelease.free()) + + hashMapToRelease } def getTaskMemoryManager(): TaskMemoryManager = { From fa1d1fd8736ba8297de30121acd392ef270b3704 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 18 Nov 2016 03:00:38 +0000 Subject: [PATCH 3/3] Update test title. --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 350ffb50231f6..37732ff6271ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1052,7 +1052,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18487: Consume all elements for show/take to avoid memory leak") { + test("SPARK-18487: Add completion listener to HashAggregate to avoid memory leak") { val rng = new scala.util.Random(42) val data = sparkContext.parallelize(Seq.tabulate(100) { i => Row(Array.fill(10)(rng.nextInt(10)))