From b38a21ef6146784e4b93ef4ce8c899f1eee14572 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 16 Nov 2015 18:30:26 -0800 Subject: [PATCH 1/9] SPARK-11633 --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 ++- .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdba..5a5b71e52dd79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -425,7 +425,8 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val attributeRewrites = + AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da02..5e00546a74c00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,6 +51,8 @@ case class Order( state: String, month: Int) +case class Individual(F1: Integer, F2: Integer) + case class WindowData( month: Int, area: String, @@ -1479,4 +1481,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } + + test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { + val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) + val df = hiveContext.createDataFrame(rdd1) + df.registerTempTable("foo") + val df2 = sql("select f1, F2 as F2 from foo") + df2.registerTempTable("foo2") + df2.registerTempTable("foo3") + + checkAnswer(sql( + """ + SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 + """.stripMargin), Row(2) :: Row(1) :: Nil) + } } From 0546772f151f83d6d3cf4d000cbe341f52545007 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:56:45 -0800 Subject: [PATCH 2/9] converge --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 15 --------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c9512fbd00aa..47962ebe6ef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -417,8 +417,7 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = - AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5e00546a74c00..61d9dcd37572b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,8 +51,6 @@ case class Order( state: String, month: Int) -case class Individual(F1: Integer, F2: Integer) - case class WindowData( month: Int, area: String, @@ -1481,18 +1479,5 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - - test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { - val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) - val df = hiveContext.createDataFrame(rdd1) - df.registerTempTable("foo") - val df2 = sql("select f1, F2 as F2 from foo") - df2.registerTempTable("foo2") - df2.registerTempTable("foo3") - - checkAnswer(sql( - """ - SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 - """.stripMargin), Row(2) :: Row(1) :: Nil) } } From b37a64f13956b6ddd0e38ddfd9fe1caee611f1a8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:58:37 -0800 Subject: [PATCH 3/9] converge --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 61d9dcd37572b..3427152b2da02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1479,5 +1479,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - } } From f0616711a5721ae65e1db1954453a5f862aaa8c6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Nov 2015 17:06:06 -0800 Subject: [PATCH 4/9] Support Persist/Cache and Unpersist in DataSet APIs --- .../scala/org/apache/spark/sql/Dataset.scala | 47 +++++++++++++++++-- .../spark/sql/execution/CacheManager.scala | 22 +++++---- .../spark/sql/execution/Queryable.scala | 2 + .../org/apache/spark/sql/DatasetSuite.scala | 22 +++++++++ 4 files changed, 79 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bdcdc5d47cbae..11d30f277fc34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.storage.StorageLevel + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ @@ -462,7 +463,7 @@ class Dataset[T] private[sql]( * combined. * * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) @@ -511,7 +512,6 @@ class Dataset[T] private[sql]( case _ => Alias(CreateStruct(rightOutput), "_2")() } - implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => @@ -580,11 +580,50 @@ class Dataset[T] private[sql]( */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + + /* ******* * + * Cache * + * ******* */ + + /** + * @since 1.6.0 + */ + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * @since 1.6.0 + */ + def cache(): this.type = persist() + + /** + * @since 1.6.0 + */ + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * @since 1.6.0 + */ + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * @since 1.3.0 + */ + def unpersist(): this.type = unpersist(blocking = false) + /* ******************** * * Internal Functions * * ******************** */ - private[sql] def logicalPlan = queryExecution.analyzed + private[sql] def logicalPlan : LogicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 293fcfe96e677..56b7a1f5d5e5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -75,12 +75,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[DataFrame]]/[[Dataset]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Queryable, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -100,7 +100,7 @@ private[sql] class CacheManager extends Logging { } } - /** Removes the data for the given [[DataFrame]] from the cache */ + /** Removes the data for the given [[DataFrame]]/[[Dataset]] from the cache */ private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -109,9 +109,11 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** Tries to remove the data for the given [[DataFrame]]/[[Dataset]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Queryable, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,8 +125,8 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[DataFrame]]/[[Dataset]] */ + private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index e86a52c149a2f..22ff40ec254af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -27,6 +28,7 @@ private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution def sqlContext: SQLContext + private[sql] def logicalPlan: LogicalPlan override def toString: String = { try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 89d964aa3e469..a9eeda677222e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.io.{ObjectInput, ObjectOutput, Externalizable} +import org.apache.spark.sql.execution.columnar.InMemoryRelation + import scala.language.postfixOps import org.apache.spark.sql.functions._ @@ -213,6 +215,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + assertResult(1, "InMemoryRelation not found, testData should have been cached") { + cached.queryExecution.withCachedData.collect { + case r: InMemoryRelation => r + }.size + } + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + } + test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupBy(v => (1, v._2)) From c135e1fefd9b621e8c073d71913cc3f45af7b308 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Nov 2015 17:20:19 -0800 Subject: [PATCH 5/9] update the @since --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 11d30f277fc34..0675f04d213df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -615,7 +615,7 @@ class Dataset[T] private[sql]( } /** - * @since 1.3.0 + * @since 1.6.0 */ def unpersist(): this.type = unpersist(blocking = false) From 25177779bab74f02b8fe722db7644f7cfcf631e3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 22 Nov 2015 20:23:59 -0800 Subject: [PATCH 6/9] adding more test cases --- .../org/apache/spark/sql/DatasetSuite.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 93601b60a4278..f8670802776dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -235,6 +235,36 @@ class DatasetSuite extends QueryTest with SharedSQLContext { cached.unpersist() } + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + ds2.persist() + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + + ds1.unpersist() + ds2.unpersist() + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + ds.persist() + + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + + ds.unpersist() + agged.unpersist() + } + test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupBy(v => (1, v._2)) From 683fa6f223f3dfda7575e5c2a066b96ac9f3552f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 24 Nov 2015 23:58:09 -0800 Subject: [PATCH 7/9] resolved all the comments --- .../org/apache/spark/sql/DataFrame.scala | 9 ++ .../scala/org/apache/spark/sql/Dataset.scala | 22 +++-- .../org/apache/spark/sql/SQLContext.scala | 9 ++ .../spark/sql/execution/CacheManager.scala | 15 ++-- .../spark/sql/execution/Queryable.scala | 2 - .../org/apache/spark/sql/CacheSuite.scala | 89 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 52 ----------- .../org/apache/spark/sql/QueryTest.scala | 5 +- 8 files changed, 131 insertions(+), 72 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5eca1db9525ec..00706ec656092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1554,6 +1554,7 @@ class DataFrame private[sql]( def distinct(): DataFrame = dropDuplicates() /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ @@ -1563,12 +1564,17 @@ class DataFrame private[sql]( } /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ def cache(): this.type = persist() /** + * Persist this [[DataFrame]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.3.0 */ @@ -1578,6 +1584,8 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. * @group basic * @since 1.3.0 */ @@ -1587,6 +1595,7 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8ce5f714206d4..5c067a67db2cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.storage.StorageLevel import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.api.java.function._ +import org.apache.spark.api.java.function._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias @@ -32,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: @@ -601,11 +601,8 @@ class Dataset[T] private[sql]( def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) - /* ******* * - * Cache * - * ******* */ - /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). * @since 1.6.0 */ def persist(): this.type = { @@ -614,11 +611,17 @@ class Dataset[T] private[sql]( } /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). * @since 1.6.0 */ def cache(): this.type = persist() /** + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. + * @group basic * @since 1.6.0 */ def persist(newLevel: StorageLevel): this.type = { @@ -627,6 +630,8 @@ class Dataset[T] private[sql]( } /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { @@ -635,6 +640,7 @@ class Dataset[T] private[sql]( } /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. * @since 1.6.0 */ def unpersist(): this.type = unpersist(blocking = false) @@ -643,7 +649,7 @@ class Dataset[T] private[sql]( * Internal Functions * * ******************** */ - private[sql] def logicalPlan : LogicalPlan = queryExecution.analyzed + private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 46bf544fd885f..22229f6630353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -338,6 +338,15 @@ class SQLContext private[sql]( cacheManager.lookupCachedData(table(tableName)).nonEmpty } + /** + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ + def isCached(qName: Queryable): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + /** * Caches the specified table in-memory. * @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 56b7a1f5d5e5c..50f6562815c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -75,7 +74,7 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]/[[Dataset]]. + * Caches the data produced by the logical representation of the given [[Queryable]]. * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because * recomputing the in-memory columnar representation of the underlying table is expensive. */ @@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]]/[[Dataset]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Queryable]] from the cache */ + private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -109,7 +108,7 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]]/[[Dataset]] from the cache + /** Tries to remove the data for the given [[Queryable]] from the cache * if it's cached */ private[sql] def tryUncacheQuery( @@ -125,12 +124,12 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]]/[[Dataset]] */ + /** Optionally returns cached data for the given [[Queryable]] */ private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. */ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index bf158f1eda888..321e2c783537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -28,7 +27,6 @@ private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution def sqlContext: SQLContext - private[sql] def logicalPlan: LogicalPlan override def toString: String = { try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala new file mode 100644 index 0000000000000..cb78885659da6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class CacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } + + ignore("persist and then map/filter with lambda functions") { + val f = (i: Int) => i + 1 + + val ds = Seq(1, 2, 3).toDS() + val mapped = ds.map(f) + mapped.cache() + + val mapped2 = ds.map(f) + assertCached(mapped2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7107f64f48769..c253fdbb8c99e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import java.io.{ObjectInput, ObjectOutput, Externalizable} -import org.apache.spark.sql.execution.columnar.InMemoryRelation - import scala.language.postfixOps import org.apache.spark.sql.functions._ @@ -230,56 +228,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } - test("persist and unpersist") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) - val cached = ds.cache() - // count triggers the caching action. It should not throw. - cached.count() - // Make sure, the Dataset is indeed cached. - assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) - assertResult(1, "InMemoryRelation not found, testData should have been cached") { - cached.queryExecution.withCachedData.collect { - case r: InMemoryRelation => r - }.size - } - // Check result. - checkAnswer( - cached, - 2, 3, 4) - // Drop the cache. - cached.unpersist() - } - - test("persist and then rebind right encoder when join 2 datasets") { - val ds1 = Seq("1", "2").toDS().as("a") - val ds2 = Seq(2, 3).toDS().as("b") - - ds1.persist() - ds2.persist() - - val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") - checkAnswer(joined, ("2", 2)) - - ds1.unpersist() - ds2.unpersist() - } - - test("persist and then groupBy columns asKey, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - ds.persist() - - val grouped = ds.groupBy($"_1").keyAs[String] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } - agged.persist() - - checkAnswer( - agged.filter(_._1 == "b"), - ("b", 3)) - - ds.unpersist() - agged.unpersist() - } - test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupBy(v => (1, v._2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 6ea1fe4ccfd89..e12458c31992e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.Queryable abstract class QueryTest extends PlanTest { @@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Queryable]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached From b9518ee73d23ad6908e841bcbe93c45424f5df6e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 30 Nov 2015 17:43:44 -0800 Subject: [PATCH 8/9] updated the codes based on the review comments from Michale Armbrust. --- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/CacheSuite.scala | 89 ------------------- 2 files changed, 1 insertion(+), 90 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2e89881b8800b..4e26250868374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -343,7 +343,7 @@ class SQLContext private[sql]( * @group cachemgmt * @since 1.3.0 */ - def isCached(qName: Queryable): Boolean = { + private[sql] def isCached(qName: Queryable): Boolean = { cacheManager.lookupCachedData(qName).nonEmpty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala deleted file mode 100644 index cb78885659da6..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/CacheSuite.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.language.postfixOps - -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext - - -class CacheSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - test("persist and unpersist") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) - val cached = ds.cache() - // count triggers the caching action. It should not throw. - cached.count() - // Make sure, the Dataset is indeed cached. - assertCached(cached) - // Check result. - checkAnswer( - cached, - 2, 3, 4) - // Drop the cache. - cached.unpersist() - assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") - } - - test("persist and then rebind right encoder when join 2 datasets") { - val ds1 = Seq("1", "2").toDS().as("a") - val ds2 = Seq(2, 3).toDS().as("b") - - ds1.persist() - assertCached(ds1) - ds2.persist() - assertCached(ds2) - - val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") - checkAnswer(joined, ("2", 2)) - - ds1.unpersist() - assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") - ds2.unpersist() - assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") - } - - test("persist and then groupBy columns asKey, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").keyAs[String] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - agged.persist() - - checkAnswer( - agged.filter(_._1 == "b"), - ("b", 3)) - - ds.unpersist() - assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") - agged.unpersist() - assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") - } - - ignore("persist and then map/filter with lambda functions") { - val f = (i: Int) => i + 1 - - val ds = Seq(1, 2, 3).toDS() - val mapped = ds.map(f) - mapped.cache() - - val mapped2 = ds.map(f) - assertCached(mapped2) - } -} From b8d287a84f354bcf59e7c258f912c55bff3da32a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 30 Nov 2015 17:45:05 -0800 Subject: [PATCH 9/9] Changed the name from CacheSuite.scala to DatasetCacheSuite.scala --- .../apache/spark/sql/DatasetCacheSuite.scala | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 0000000000000..3a283a4e1f610 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +}